"""
|
|
Python only API for running all GPT4All models.
|
|
"""
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Dict, List
|
|
|
|
import requests
|
|
from tqdm import tqdm
|
|
|
|
from . import pyllmodel
|
|
|
|
# TODO: move to config
|
|
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
|
|
|
|


class GPT4All:
    """Python API for retrieving and interacting with GPT4All models.

    Attributes:
        model: Pointer to underlying C model.
    """

    def __init__(self, model_name: str, model_path: str = None, model_type: str = None, allow_download=True):
        """
        Constructor.

        Args:
            model_name: Name of GPT4All or custom model. Including ".bin" file extension is optional but encouraged.
            model_path: Path to directory containing model file or, if file does not exist, where to download model.
                Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
            model_type: Model architecture to use - currently, options are 'llama', 'gptj', or 'mpt'. Only required
                if model is custom. Note that these models still must be built from llama.cpp or GPTJ ggml
                architecture. Default is None.
            allow_download: Allow API to download models from gpt4all.io. Default is True.
        """
        self.model = None

        # Model type provided for when model is custom
        if model_type:
            self.model = GPT4All.get_model_from_type(model_type)
        # Else get model from gpt4all model filenames
        else:
            self.model = GPT4All.get_model_from_name(model_name)

        # Retrieve model and download if allowed
        model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
        self.model.load_model(model_dest)
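
    # A minimal usage sketch, assuming "ggml-gpt4all-j-v1.3-groovy" is still
    # listed on gpt4all.io (the model name is illustrative):
    #
    #   from gpt4all import GPT4All
    #   gpt = GPT4All("ggml-gpt4all-j-v1.3-groovy")
    #   messages = [{"role": "user", "content": "Name 3 colors"}]
    #   gpt.chat_completion(messages)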

    @staticmethod
    def list_models():
        """
        Fetch model list from https://gpt4all.io/models/models.json.

        Returns:
            Model list in JSON format.
        """
        response = requests.get("https://gpt4all.io/models/models.json")
        model_json = json.loads(response.content)
        return model_json
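
    # The code below relies only on the "filename" key of each entry; a
    # hypothetical entry might look like:
    #
    #   {"filename": "ggml-gpt4all-j-v1.3-groovy.bin", ...}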

    @staticmethod
    def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True) -> str:
        """
        Find model file, and if it doesn't exist, download the model.

        Args:
            model_name: Name of model.
            model_path: Path to find model. Default is None, in which case path is set to
                ~/.cache/gpt4all/.
            allow_download: Allow API to download model from gpt4all.io. Default is True.

        Returns:
            Model file destination.
        """

        model_filename = model_name
        if not model_filename.endswith(".bin"):
            model_filename += ".bin"

        # Validate download directory
        if model_path is None:
            model_path = DEFAULT_MODEL_DIRECTORY
            if not os.path.exists(DEFAULT_MODEL_DIRECTORY):
                try:
                    os.makedirs(DEFAULT_MODEL_DIRECTORY)
                except OSError:
                    raise ValueError(
                        "Failed to create model download directory at ~/.cache/gpt4all/. "
                        "Please specify model_path."
                    )
        else:
            model_path = model_path.replace("\\", "\\\\")

        if os.path.exists(model_path):
            model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\")
            if os.path.exists(model_dest):
                print("Found model file.")
                return model_dest

            # If model file does not exist, download
            elif allow_download:
                # Make sure valid model filename before attempting download
                model_match = False
                for item in GPT4All.list_models():
                    if model_filename == item["filename"]:
                        model_match = True
                        break
                if not model_match:
                    raise ValueError(f"Model filename not in model list: {model_filename}")
                return GPT4All.download_model(model_filename, model_path)
            else:
                raise ValueError(f"Model file does not exist and downloads are disabled: {model_dest}")
        else:
            raise ValueError(f"Invalid model directory: {model_path}")
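
    # Sketch: resolving a listed model to a local path (downloads on first use
    # when allow_download is True; model name is illustrative):
    #
    #   model_file = GPT4All.retrieve_model("ggml-gpt4all-j-v1.3-groovy")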

    @staticmethod
    def download_model(model_filename: str, model_path: str) -> str:
        """
        Download model from https://gpt4all.io.

        Args:
            model_filename: Filename of model (with .bin extension).
            model_path: Path to download model to.

        Returns:
            Model file destination.
        """

        def get_download_url(model_filename):
            return f"https://gpt4all.io/models/{model_filename}"

        # Download model
        download_path = os.path.join(model_path, model_filename).replace("\\", "\\\\")
        download_url = get_download_url(model_filename)

        # TODO: Find good way of safely removing file that got interrupted.
        response = requests.get(download_url, stream=True)
        total_size_in_bytes = int(response.headers.get("content-length", 0))
        block_size = 1048576  # 1 MB
        progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
        with open(download_path, "wb") as file:
            for data in response.iter_content(block_size):
                progress_bar.update(len(data))
                file.write(data)
        progress_bar.close()

        # Validate download was successful
        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
            raise RuntimeError(
                "An error occurred during download. Downloaded file may not work."
            )

        print("Model downloaded at: " + download_path)
        return download_path

    def generate(self, prompt: str, **generate_kwargs) -> str:
        """
        Surfaced method for running generate without accessing the model object directly.

        Args:
            prompt: Raw string to be passed to model.
            **generate_kwargs: Optional kwargs to pass to prompt context.

        Returns:
            Raw string of generated model response.
        """
        return self.model.generate(prompt, **generate_kwargs)
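
    # Raw-generation sketch; the prompt string is passed through unformatted,
    # so any "### Instruction:"-style scaffolding is up to the caller:
    #
    #   gpt = GPT4All("ggml-gpt4all-j-v1.3-groovy")
    #   output = gpt.generate("Name 3 colors: ")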

    def chat_completion(self,
                        messages: List[Dict],
                        default_prompt_header: bool = True,
                        default_prompt_footer: bool = True,
                        verbose: bool = True,
                        **generate_kwargs) -> dict:
        """
        Format list of message dictionaries into a prompt and call model
        generate on prompt. Returns a response dictionary with metadata and
        generated content.

        Args:
            messages: List of dictionaries. Each dictionary should have a "role" key
                with value of "system", "assistant", or "user" and a "content" key with a
                string value. Messages are organized such that "system" messages are at top of prompt,
                and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as
                "Response: {content}".
            default_prompt_header: If True (default), add default prompt header after any system role messages and
                before user/assistant role messages.
            default_prompt_footer: If True (default), add default footer at end of prompt.
            verbose: If True (default), print full prompt and generated response.
            **generate_kwargs: Optional kwargs to pass to prompt context.

        Returns:
            Response dictionary with:
                "model": name of model.
                "usage": a dictionary with number of full prompt tokens, number of
                    generated tokens in response, and total tokens.
                "choices": List of message dictionary where "content" is generated response and "role" is set
                    as "assistant". Right now, only one choice is returned by model.
        """

        full_prompt = self._build_prompt(messages,
                                         default_prompt_header=default_prompt_header,
                                         default_prompt_footer=default_prompt_footer)
        if verbose:
            print(full_prompt)

        response = self.model.generate(full_prompt, **generate_kwargs)

        if verbose:
            print(response)

        # NOTE: "usage" counts string lengths (characters), not true token counts.
        response_dict = {
            "model": self.model.model_name,
            "usage": {"prompt_tokens": len(full_prompt),
                      "completion_tokens": len(response),
                      "total_tokens": len(full_prompt) + len(response)},
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": response
                    }
                }
            ]
        }

        return response_dict
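
    # Example messages list mixing roles (the content strings are illustrative):
    #
    #   messages = [
    #       {"role": "system", "content": "You are an expert on colors."},
    #       {"role": "user", "content": "Name 3 colors"},
    #   ]
    #   gpt.chat_completion(messages, verbose=False)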

    @staticmethod
    def _build_prompt(messages: List[Dict],
                      default_prompt_header=True,
                      default_prompt_footer=False) -> str:
        # Helper method to format messages into prompt.
        full_prompt = ""

        for message in messages:
            if message["role"] == "system":
                system_message = message["content"] + "\n"
                full_prompt += system_message

        if default_prompt_header:
            full_prompt += """### Instruction:
            The prompt below is a question to answer, a task to complete, or a conversation
            to respond to; decide which and write an appropriate response.
            \n### Prompt: """

        for message in messages:
            if message["role"] == "user":
                user_message = "\n" + message["content"]
                full_prompt += user_message
            if message["role"] == "assistant":
                assistant_message = "\n### Response: " + message["content"]
                full_prompt += assistant_message

        if default_prompt_footer:
            full_prompt += "\n### Response:"

        return full_prompt
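
    # For messages = [{"role": "user", "content": "Hi"}] with the default
    # header and footer, the prompt comes out roughly as:
    #
    #   ### Instruction: ... write an appropriate response.
    #   ### Prompt:
    #   Hi
    #   ### Response: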

    @staticmethod
    def get_model_from_type(model_type: str) -> pyllmodel.LLModel:
        # This needs to be updated for each new model type
        # TODO: Might be worth converting model_type to enum
        if model_type == "gptj":
            return pyllmodel.GPTJModel()
        elif model_type == "llama":
            return pyllmodel.LlamaModel()
        elif model_type == "mpt":
            return pyllmodel.MPTModel()
        else:
            raise ValueError(f"No corresponding model for model_type: {model_type}")
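
    # Custom-model sketch ("my-llama-model.bin" and the directory are
    # hypothetical; a local file must already exist at that path):
    #
    #   gpt = GPT4All("my-llama-model.bin", model_path="/path/to/models",
    #                 model_type="llama", allow_download=False)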

    @staticmethod
    def get_model_from_name(model_name: str) -> pyllmodel.LLModel:
        # This needs to be updated for each new model
        # NOTE: We are doing this preprocessing a lot, maybe there's a better way to organize
        if not model_name.endswith(".bin"):
            model_name += ".bin"

        GPTJ_MODELS = [
            "ggml-gpt4all-j-v1.3-groovy.bin",
            "ggml-gpt4all-j-v1.2-jazzy.bin",
            "ggml-gpt4all-j-v1.1-breezy.bin",
            "ggml-gpt4all-j.bin"
        ]

        LLAMA_MODELS = [
            "ggml-gpt4all-l13b-snoozy.bin",
            "ggml-vicuna-7b-1.1-q4_2.bin",
            "ggml-vicuna-13b-1.1-q4_2.bin",
            "ggml-wizardLM-7B.q4_2.bin",
            "ggml-stable-vicuna-13B.q4_2.bin",
            "ggml-nous-gpt4-vicuna-13b.bin"
        ]

        MPT_MODELS = [
            "ggml-mpt-7b-base.bin",
            "ggml-mpt-7b-chat.bin",
            "ggml-mpt-7b-instruct.bin"
        ]

        if model_name in GPTJ_MODELS:
            return pyllmodel.GPTJModel()
        elif model_name in LLAMA_MODELS:
            return pyllmodel.LlamaModel()
        elif model_name in MPT_MODELS:
            return pyllmodel.MPTModel()
        else:
            err_msg = f"""No corresponding model for provided filename {model_name}.
            If this is a custom model, make sure to specify a valid model_type.
            """
            raise ValueError(err_msg)