Refactor get_repo() and load_files() functions to use Repo() without root_dir. Refactored `load_files` and added a delay when creating the vector store.

pull/19/head
rsaryev 9 months ago
parent 3d3e2dabd5
commit b978a76402

.gitignore
@@ -3,4 +3,5 @@
 /.vscode/
 /.venv/
 /talk_codebase/__pycache__/
-.DS_Store
+.DS_Store
+/vector_store/

README.md
@@ -14,6 +14,7 @@ talk-codebase is still under development and is recommended for educational purp
 ## Installation
 Requirement Python 3.8.1 or higher
+Your project must be in a git repository
 ```bash
 pip install talk-codebase

pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "talk-codebase"
-version = "0.1.46"
+version = "0.1.47"
 description = "talk-codebase is a powerful tool for querying and analyzing codebases."
 authors = ["Saryev Rustam <rustam1997@gmail.com>"]
 readme = "README.md"

talk_codebase/cli.py
@@ -6,6 +6,7 @@ from talk_codebase.config import CONFIGURE_STEPS, save_config, get_config, confi
     remove_model_type, remove_model_name_local
 from talk_codebase.consts import DEFAULT_CONFIG
 from talk_codebase.llm import factory_llm
+from talk_codebase.utils import get_repo


 def check_python_version():
@@ -44,10 +45,14 @@ def chat_loop(llm):
         llm.send_query(query)


-def chat(root_dir=None):
+def chat():
     configure(False)
     config = get_config()
-    llm = factory_llm(root_dir, config)
+    repo = get_repo()
+    if not repo:
+        print("🤖 Git repository not found")
+        sys.exit(1)
+    llm = factory_llm(repo.working_dir, config)
     chat_loop(llm)
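For context, a minimal sketch (not part of the commit) of the GitPython behaviour the new `chat()` relies on: a bare `Repo()` call inspects only the current working directory, raises `InvalidGitRepositoryError` when that directory is not a repository, and exposes the checkout root as `working_dir`, which is what now gets passed to `factory_llm`:

```python
# Illustrative only: the Repo() semantics behind get_repo() and chat().
from git import Repo, InvalidGitRepositoryError

try:
    # Repo() defaults to the current working directory; it does not search
    # parent directories unless search_parent_directories=True is passed.
    repo = Repo()
    print(f"repository root: {repo.working_dir}")
except InvalidGitRepositoryError:
    print("🤖 Git repository not found")
```

One consequence: running the tool from a subdirectory of the repository will also hit the "Git repository not found" branch, because parent directories are not searched.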

talk_codebase/config.py
@@ -152,7 +152,6 @@ def configure_model_type(config):
     ).ask()
     config["model_type"] = model_type
     save_config(config)
-    print("🤖 Model type saved!")


 CONFIGURE_STEPS = [

talk_codebase/llm.py
@@ -1,4 +1,5 @@
 import os
+import time
 from typing import Optional

 import gpt4all
@@ -40,7 +41,7 @@ class BaseLLM:
         if new_db is not None:
             return new_db.as_retriever(search_kwargs={"k": k})
-        docs = load_files(root_dir)
+        docs = load_files()
         if len(docs) == 0:
             print("✘ No documents found")
             exit(0)
@@ -60,9 +61,13 @@ class BaseLLM:
             exit(0)
         spinners = Halo(text=f"Creating vector store", spinner='dots').start()
-        db = FAISS.from_documents(texts, embeddings)
-        db.add_documents(texts)
-        db.save_local(index_path)
+        db = FAISS.from_documents([texts[0]], embeddings)
+        for i, text in enumerate(texts[1:]):
+            spinners.text = f"Creating vector store ({i + 1}/{len(texts)})"
+            db.add_documents([text])
+            db.save_local(index_path)
+            time.sleep(1.5)
         spinners.succeed(f"Created vector store")
         return db.as_retriever(search_kwargs={"k": k})
@@ -93,7 +98,8 @@ class LocalLLM(BaseLLM):
         model_n_ctx = int(self.config.get("max_tokens"))
         model_n_batch = int(self.config.get("n_batch"))
         callbacks = CallbackManager([StreamStdOut()])
-        llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks, verbose=False)
+        llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks,
+                       verbose=False)
         llm.client.verbose = False
         return llm
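The middle hunk replaces a bulk build that embedded every chunk twice (`FAISS.from_documents(texts, embeddings)` immediately followed by `db.add_documents(texts)`) with an incremental one: seed the store with the first chunk, then add the rest one at a time so the spinner can report progress, checkpoint the index to disk after each chunk, and pause between chunks to throttle the embedding backend. A self-contained sketch of the same pattern, using `FakeEmbeddings` as a stand-in so it runs offline (the 1.5 s value is carried over from the diff, not tuned; `faiss-cpu` must be installed):

```python
# Standalone sketch of the incremental vector-store build from llm.py.
import time

from langchain.docstore.document import Document
from langchain.embeddings import FakeEmbeddings  # stand-in for the real model
from langchain.vectorstores import FAISS

texts = [Document(page_content=f"chunk {i}") for i in range(4)]
embeddings = FakeEmbeddings(size=32)

# Seed with the first document, then grow the index one document at a time.
db = FAISS.from_documents([texts[0]], embeddings)
for i, text in enumerate(texts[1:]):
    db.add_documents([text])       # embeds and indexes a single chunk
    db.save_local("vector_store")  # persist progress after every chunk
    time.sleep(1.5)                # crude rate limit for embedding APIs

print(db.similarity_search("chunk 2", k=1)[0].page_content)
```

The per-chunk `save_local` keeps an up-to-date copy of the index on disk, at the cost of rewriting the whole index on every iteration, so total write volume grows quadratically with corpus size.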

talk_codebase/utils.py
@@ -1,6 +1,3 @@
-import glob
-import multiprocessing
-import os
 import sys

 import tiktoken
@@ -11,23 +8,13 @@ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from talk_codebase.consts import LOADER_MAPPING, EXCLUDE_FILES


-def get_repo(root_dir):
+def get_repo():
     try:
-        return Repo(root_dir)
+        return Repo()
     except:
         return None


-def is_ignored(path, root_dir):
-    repo = get_repo(root_dir)
-    if repo is None:
-        return False
-    if not os.path.exists(path):
-        return False
-    ignored = repo.ignored(path)
-    return len(ignored) > 0
-
-
 class StreamStdOut(StreamingStdOutCallbackHandler):
     def on_llm_new_token(self, token: str, **kwargs) -> None:
         sys.stdout.write(token)
@@ -41,26 +28,24 @@ class StreamStdOut(StreamingStdOutCallbackHandler):
         sys.stdout.flush()


-def load_files(root_dir):
-    num_cpus = multiprocessing.cpu_count()
-    with multiprocessing.Pool(num_cpus) as pool:
-        futures = []
-        for file_path in glob.glob(os.path.join(root_dir, '**/*'), recursive=True):
-            if is_ignored(file_path, root_dir):
-                continue
-            if any(
-                    file_path.endswith(exclude_file) for exclude_file in EXCLUDE_FILES):
-                continue
-            for ext in LOADER_MAPPING:
-                if file_path.endswith(ext):
-                    print('\r' + f'📂 Loading files: {file_path}')
-                    args = LOADER_MAPPING[ext]['args']
-                    loader = LOADER_MAPPING[ext]['loader'](file_path, *args)
-                    futures.append(pool.apply_async(loader.load))
-        docs = []
-        for future in futures:
-            docs.extend(future.get())
-        return docs
+def load_files():
+    repo = get_repo()
+    if repo is None:
+        return []
+    files = []
+    tree = repo.tree()
+    for blob in tree.traverse():
+        path = blob.path
+        if any(
+                path.endswith(exclude_file) for exclude_file in EXCLUDE_FILES):
+            continue
+        for ext in LOADER_MAPPING:
+            if path.endswith(ext):
+                print('\r' + f'📂 Loading files: {path}')
+                args = LOADER_MAPPING[ext]['args']
+                loader = LOADER_MAPPING[ext]['loader'](path, *args)
+                files.extend(loader.load())
+    return files


 def calculate_cost(texts, model_name):
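The rewritten `load_files()` walks the committed git tree instead of globbing the filesystem, which is what lets the commit delete `is_ignored()` and the multiprocessing pool: gitignored and untracked files never appear in the tree, and the loaders now run sequentially. A minimal sketch of that traversal (illustrative only; `blob.path` is relative to the repository root, so the loaders resolve correctly only when the process runs from that root, and only committed content is visible, hence the new README requirement):

```python
# Sketch of the GitPython traversal behind the new load_files().
from git import Repo

repo = Repo()  # assumes the current directory is the repository root
for item in repo.tree().traverse():
    if item.type == "blob":  # "blob" entries are files, "tree" entries are directories
        print(item.path)     # e.g. talk_codebase/utils.py
```

Note that the diff's `load_files()` does not check `blob.type` itself; directory entries are filtered out implicitly because they rarely match a `LOADER_MAPPING` extension.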
