Compare commits

...

10 Commits

Author SHA1 Message Date
Jared Van Bortel 1eb7ca5865 chatllm: skip unnecessary calls to LLModel::saveState
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
2 weeks ago
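Although the commit message is terse, the diff below shows the mechanism: a new m_pristineLoadedState flag tracks whether anything can have changed since the model state was last restored, and saveState() returns early while it is set. A condensed sketch of just that guard (the serialization body, unchanged by this PR, is elided):

// Condensed from the chatllm.cpp hunk further down; only the new early
// return is shown, the actual serialization into m_state is unchanged.
void ChatLLM::saveState()
{
    // m_pristineLoadedState is set while the restored state is untouched
    // (see the comment added in chatllm.h below), so re-serializing the
    // KV cache here would be redundant work.
    if (!isModelLoaded() || m_pristineLoadedState)
        return;

    // ... existing serialization into m_state ...
}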
Jared Van Bortel fd0d26c8af chatllm: skip context switch if unload is pending
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
2 weeks ago
Jared Van Bortel 8881a16398 chatllm: simplify requesting of context switch
The boolean flag doesn't do anything useful.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
2 weeks ago
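The practical effect, condensed from the chatllm hunks below: the setShouldTrySwitchContext(bool) setter, its m_shouldTrySwitchContext atomic, and the handleShouldTrySwitchContextChanged() slot collapse into a single request method whose queued signal carries the target ModelInfo directly. A sketch of the replacement (essentially the new code from the diff):

void ChatLLM::requestTrySwitchContext()
{
    m_shouldBeLoaded = true; // atomic
    emit trySwitchContextRequested(modelInfo());
}

Since the boolean only ever meant "a request happened", passing the model with the signal removes both the flag and the extra round trip through a change handler.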
Jared Van Bortel 4432471c91 ChatView: do not show lower "reload" button on error
This aligns its behavior with the upper "reload" button.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
2 weeks ago
Jared Van Bortel 7ef9692c7b chat: show "waiting for model" state
When switching context, this state can take a significant amount of
time. Separate it out so the user is less likely to think GPT4All is
completely stuck.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
2 weeks ago
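Concretely, the trySwitchContextInProgress property changes from a bool to an int with three states (0 = not in progress, 1 = waiting for the model, 2 = actively switching context), so the UI can show a distinct "Waiting for model" label while ChatLLM is still blocked acquiring the model from the store. A condensed sketch of the flow, assembled from the chat.cpp hunks below:

// Chat enters the "waiting" state immediately, before ChatLLM has acquired
// anything from the store (acquireModel() can block for a while).
void Chat::trySwitchContextOfLoadedModel()
{
    m_trySwitchContextInProgress = 1; // 1 = waiting for model
    emit trySwitchContextInProgressChanged();
    m_llmodel->requestTrySwitchContext();
}

// ChatLLM later reports 2 ("working") once the store hands back the
// already-loaded model, and 0 when it finishes or bails out.
void Chat::handleTrySwitchContextOfLoadedModelCompleted(int value)
{
    m_trySwitchContextInProgress = value;
    emit trySwitchContextInProgressChanged();
}

The chat view hunk near the end of the diff maps 1 to "Waiting for model..." and 2 to "Switching context...".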
Jared Van Bortel 2c53e50222 chatllm: do not report failure on cancellation
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
2 weeks ago
Jared Van Bortel b6cf5a24d4 chatllm: use std::optional for m_availableModels
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
2 weeks ago
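The store only ever holds zero or one parked model (the old code even asserted m_availableModels.count() < 2), so the QVector is replaced with a std::optional plus the same mutex/wait-condition hand-off. A minimal self-contained illustration of the pattern, using a hypothetical SingleSlotStore rather than the real LLModelStore:

#include <QMutex>
#include <QMutexLocker>
#include <QWaitCondition>
#include <optional>
#include <utility>

// Hypothetical single-slot store: acquire() blocks until an item is parked,
// release() parks exactly one item. Mirrors the LLModelStore change below.
template <typename Item>
class SingleSlotStore
{
public:
    Item acquire()
    {
        QMutexLocker locker(&m_mutex);
        while (!m_slot)
            m_condition.wait(locker.mutex());
        Item item = std::move(*m_slot);
        m_slot.reset();
        return item;
    }

    void release(Item &&item)
    {
        QMutexLocker locker(&m_mutex);
        Q_ASSERT(!m_slot); // releasing twice without an acquire is a bug
        m_slot = std::move(item);
        m_condition.wakeAll();
    }

private:
    std::optional<Item> m_slot;
    QMutex m_mutex;
    QWaitCondition m_condition;
};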
Jared Van Bortel c545f39067 chatllm: use unique_ptr for LLModelInfo::model
This fixes a memory leak if there was a model in the store on exit, e.g.
if the user loads a model and then switches to a chat associated with a
different model file without loading it.

ChatLLM::destroyStore is added to fix a heap-use-after-free caused by
the unique_ptr being freed too late. Global destructors are hard.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
2 weeks ago
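Two related lifetime issues are addressed here. First, a raw LLModel* parked in the store was never deleted on exit, so LLModelInfo::model becomes a std::unique_ptr<LLModel> and ownership moves with the info object (note the releaseModel(LLModelInfo &&) signature change below). Second, leaving that unique_ptr to be destroyed by the store's own global destructor runs the model destructor very late, after state it may depend on is already gone, hence the explicit ChatLLM::destroyStore() hook called from ChatListModel::destroyChats(). A minimal sketch of the shutdown pattern with hypothetical names (Store/Resource), not the actual gpt4all classes:

#include <memory>

struct Resource
{
    // Stand-in for an LLModel whose destructor may touch other global
    // state, which is why destruction order matters.
    ~Resource() = default;
};

class Store
{
public:
    static Store *globalInstance()
    {
        static Store instance; // destroyed at some unspecified point during exit
        return &instance;
    }

    void park(std::unique_ptr<Resource> r) { m_parked = std::move(r); }

    // Called explicitly during orderly shutdown, so the parked Resource is
    // destroyed while the rest of the application is still alive.
    void destroy() { m_parked.reset(); }

private:
    std::unique_ptr<Resource> m_parked;
};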
Jared Van Bortel 21d2392ee7 chat: initialize status at top of loadModel
There is no reason for reloadModel to show slightly different status
information from setModelInfo.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
2 weeks ago
Jared Van Bortel 78f021f2e6 chat: make isCurrentlyLoading a chat property based on progress
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
2 weeks ago
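Previously the QML window kept its own isCurrentlyLoading property and updated it from a Connections handler on modelLoadingPercentageChanged; now Chat derives it (and isModelLoaded) directly from the percentage and exposes it as a notifiable property. Condensed from the chat.h and chat.cpp hunks below:

// Both states are pure functions of the loading progress.
bool Chat::isModelLoaded() const
{
    return m_modelLoadingPercentage == 1.0f;
}

bool Chat::isCurrentlyLoading() const
{
    return m_modelLoadingPercentage > 0.0f && m_modelLoadingPercentage < 1.0f;
}

handleModelLoadingPercentageChanged() then compares the derived values before and after the update and emits isCurrentlyLoadingChanged() / isModelLoadedChanged() only when one of them actually flips.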

@ -95,16 +95,6 @@ void Chat::processSystemPrompt()
emit processSystemPromptRequested();
}
bool Chat::isModelLoaded() const
{
return m_modelLoadingPercentage == 1.0f;
}
float Chat::modelLoadingPercentage() const
{
return m_modelLoadingPercentage;
}
void Chat::resetResponseState()
{
if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
@ -167,9 +157,16 @@ void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
if (loadingPercentage == m_modelLoadingPercentage)
return;
bool wasLoading = isCurrentlyLoading();
bool wasLoaded = isModelLoaded();
m_modelLoadingPercentage = loadingPercentage;
emit modelLoadingPercentageChanged();
if (m_modelLoadingPercentage == 1.0f || m_modelLoadingPercentage == 0.0f)
if (isCurrentlyLoading() != wasLoading)
emit isCurrentlyLoadingChanged();
if (isModelLoaded() != wasLoaded)
emit isModelLoadedChanged();
}
@ -247,10 +244,6 @@ void Chat::setModelInfo(const ModelInfo &modelInfo)
if (m_modelInfo == modelInfo && isModelLoaded())
return;
m_modelLoadingPercentage = std::numeric_limits<float>::min(); // small non-zero positive value
emit isModelLoadedChanged();
m_modelLoadingError = QString();
emit modelLoadingErrorChanged();
m_modelInfo = modelInfo;
emit modelInfoChanged();
emit modelChangeRequested(modelInfo);
@ -320,9 +313,9 @@ void Chat::forceReloadModel()
void Chat::trySwitchContextOfLoadedModel()
{
m_trySwitchContextInProgress = true;
m_trySwitchContextInProgress = 1;
emit trySwitchContextInProgressChanged();
m_llmodel->setShouldTrySwitchContext(true);
m_llmodel->requestTrySwitchContext();
}
void Chat::generatedNameChanged(const QString &name)
@ -343,8 +336,10 @@ void Chat::handleRecalculating()
void Chat::handleModelLoadingError(const QString &error)
{
auto stream = qWarning().noquote() << "ERROR:" << error << "id";
stream.quote() << id();
if (!error.isEmpty()) {
auto stream = qWarning().noquote() << "ERROR:" << error << "id";
stream.quote() << id();
}
m_modelLoadingError = error;
emit modelLoadingErrorChanged();
}
@ -381,8 +376,8 @@ void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
emit modelInfoChanged();
}
void Chat::handleTrySwitchContextOfLoadedModelCompleted() {
m_trySwitchContextInProgress = false;
void Chat::handleTrySwitchContextOfLoadedModelCompleted(int value) {
m_trySwitchContextInProgress = value;
emit trySwitchContextInProgressChanged();
}

@ -17,6 +17,7 @@ class Chat : public QObject
Q_PROPERTY(QString name READ name WRITE setName NOTIFY nameChanged)
Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged)
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
Q_PROPERTY(bool isCurrentlyLoading READ isCurrentlyLoading NOTIFY isCurrentlyLoadingChanged)
Q_PROPERTY(float modelLoadingPercentage READ modelLoadingPercentage NOTIFY modelLoadingPercentageChanged)
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged)
@ -30,7 +31,8 @@ class Chat : public QObject
Q_PROPERTY(QString device READ device NOTIFY deviceChanged);
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY fallbackReasonChanged);
Q_PROPERTY(LocalDocsCollectionsModel *collectionModel READ collectionModel NOTIFY collectionModelChanged)
Q_PROPERTY(bool trySwitchContextInProgress READ trySwitchContextInProgress NOTIFY trySwitchContextInProgressChanged)
// 0=no, 1=waiting, 2=working
Q_PROPERTY(int trySwitchContextInProgress READ trySwitchContextInProgress NOTIFY trySwitchContextInProgressChanged)
QML_ELEMENT
QML_UNCREATABLE("Only creatable from c++!")
@ -63,8 +65,9 @@ public:
Q_INVOKABLE void reset();
Q_INVOKABLE void processSystemPrompt();
Q_INVOKABLE bool isModelLoaded() const;
Q_INVOKABLE float modelLoadingPercentage() const;
bool isModelLoaded() const { return m_modelLoadingPercentage == 1.0f; }
bool isCurrentlyLoading() const { return m_modelLoadingPercentage > 0.0f && m_modelLoadingPercentage < 1.0f; }
float modelLoadingPercentage() const { return m_modelLoadingPercentage; }
Q_INVOKABLE void prompt(const QString &prompt);
Q_INVOKABLE void regenerateResponse();
Q_INVOKABLE void stopGenerating();
@ -106,7 +109,7 @@ public:
QString device() const { return m_device; }
QString fallbackReason() const { return m_fallbackReason; }
bool trySwitchContextInProgress() const { return m_trySwitchContextInProgress; }
int trySwitchContextInProgress() const { return m_trySwitchContextInProgress; }
public Q_SLOTS:
void serverNewPromptResponsePair(const QString &prompt);
@ -116,6 +119,7 @@ Q_SIGNALS:
void nameChanged();
void chatModelChanged();
void isModelLoadedChanged();
void isCurrentlyLoadingChanged();
void modelLoadingPercentageChanged();
void modelLoadingWarning(const QString &warning);
void responseChanged();
@ -139,7 +143,6 @@ Q_SIGNALS:
void deviceChanged();
void fallbackReasonChanged();
void collectionModelChanged();
void trySwitchContextOfLoadedModelCompleted(bool);
void trySwitchContextInProgressChanged();
private Q_SLOTS:
@ -155,7 +158,7 @@ private Q_SLOTS:
void handleFallbackReasonChanged(const QString &device);
void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
void handleModelInfoChanged(const ModelInfo &modelInfo);
void handleTrySwitchContextOfLoadedModelCompleted();
void handleTrySwitchContextOfLoadedModelCompleted(int value);
private:
QString m_id;
@ -180,7 +183,8 @@ private:
float m_modelLoadingPercentage = 0.0f;
LocalDocsCollectionsModel *m_collectionModel;
bool m_firstResponse = true;
bool m_trySwitchContextInProgress = false;
int m_trySwitchContextInProgress = 0;
bool m_isCurrentlyLoading = false;
};
#endif // CHAT_H

@ -195,7 +195,11 @@ public:
int count() const { return m_chats.size(); }
// stop ChatLLM threads for clean shutdown
void destroyChats() { for (auto *chat: m_chats) { chat->destroy(); } }
void destroyChats()
{
for (auto *chat: m_chats) { chat->destroy(); }
ChatLLM::destroyStore();
}
void removeChatFile(Chat *chat) const;
Q_INVOKABLE void saveChats();

@ -30,16 +30,17 @@ public:
static LLModelStore *globalInstance();
LLModelInfo acquireModel(); // will block until llmodel is ready
void releaseModel(const LLModelInfo &info); // must be called when you are done
void releaseModel(LLModelInfo &&info); // must be called when you are done
void destroy();
private:
LLModelStore()
{
// seed with empty model
m_availableModels.append(LLModelInfo());
m_availableModel = LLModelInfo();
}
~LLModelStore() {}
QVector<LLModelInfo> m_availableModels;
std::optional<LLModelInfo> m_availableModel;
QMutex m_mutex;
QWaitCondition m_condition;
friend class MyLLModelStore;
@ -55,19 +56,27 @@ LLModelStore *LLModelStore::globalInstance()
LLModelInfo LLModelStore::acquireModel()
{
QMutexLocker locker(&m_mutex);
while (m_availableModels.isEmpty())
while (!m_availableModel)
m_condition.wait(locker.mutex());
return m_availableModels.takeFirst();
auto first = std::move(*m_availableModel);
m_availableModel.reset();
return first;
}
void LLModelStore::releaseModel(const LLModelInfo &info)
void LLModelStore::releaseModel(LLModelInfo &&info)
{
QMutexLocker locker(&m_mutex);
m_availableModels.append(info);
Q_ASSERT(m_availableModels.count() < 2);
Q_ASSERT(!m_availableModel);
m_availableModel = std::move(info);
m_condition.wakeAll();
}
void LLModelStore::destroy()
{
QMutexLocker locker(&m_mutex);
m_availableModel.reset();
}
ChatLLM::ChatLLM(Chat *parent, bool isServer)
: QObject{nullptr}
, m_promptResponseTokens(0)
@ -76,7 +85,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_shouldBeLoaded(false)
, m_forceUnloadModel(false)
, m_markedForDeletion(false)
, m_shouldTrySwitchContext(false)
, m_stopGenerating(false)
, m_timer(nullptr)
, m_isServer(isServer)
@ -88,7 +96,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
moveToThread(&m_llmThread);
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
Qt::QueuedConnection); // explicitly queued
connect(this, &ChatLLM::shouldTrySwitchContextChanged, this, &ChatLLM::handleShouldTrySwitchContextChanged,
connect(this, &ChatLLM::trySwitchContextRequested, this, &ChatLLM::trySwitchContextOfLoadedModel,
Qt::QueuedConnection); // explicitly queued
connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
@ -108,7 +116,8 @@ ChatLLM::~ChatLLM()
destroy();
}
void ChatLLM::destroy() {
void ChatLLM::destroy()
{
m_stopGenerating = true;
m_llmThread.quit();
m_llmThread.wait();
@ -116,11 +125,15 @@ void ChatLLM::destroy() {
// The only time we should have a model loaded here is on shutdown
// as we explicitly unload the model in all other circumstances
if (isModelLoaded()) {
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
m_llModelInfo.model.reset();
}
}
void ChatLLM::destroyStore()
{
LLModelStore::globalInstance()->destroy();
}
void ChatLLM::handleThreadStarted()
{
m_timer = new TokenTimer(this);
@ -161,7 +174,7 @@ bool ChatLLM::loadDefaultModel()
return loadModel(defaultModel);
}
bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
void ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
{
// We're trying to see if the store already has the model fully loaded that we wish to use
// and if so we just acquire it from the store and switch the context and return true. If the
@ -169,10 +182,11 @@ bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
// If we're already loaded or a server or we're reloading to change the variant/device or the
// modelInfo is empty, then this should fail
if (isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty()) {
m_shouldTrySwitchContext = false;
emit trySwitchContextOfLoadedModelCompleted(false);
return false;
if (
isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty() || !m_shouldBeLoaded
) {
emit trySwitchContextOfLoadedModelCompleted(0);
return;
}
QString filePath = modelInfo.dirpath + modelInfo.filename();
@ -180,33 +194,28 @@ bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
// The store gave us no already loaded model, the wrong type of model, then give it back to the
// store and fail
if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo) {
LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
m_llModelInfo = LLModelInfo();
m_shouldTrySwitchContext = false;
emit trySwitchContextOfLoadedModelCompleted(false);
return false;
if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo || !m_shouldBeLoaded) {
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
emit trySwitchContextOfLoadedModelCompleted(0);
return;
}
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
// We should be loaded and now we are
m_shouldBeLoaded = true;
m_shouldTrySwitchContext = false;
emit trySwitchContextOfLoadedModelCompleted(2);
// Restore, signal and process
restoreState();
emit modelLoadingPercentageChanged(1.0f);
emit trySwitchContextOfLoadedModelCompleted(true);
emit trySwitchContextOfLoadedModelCompleted(0);
processSystemPrompt();
return true;
}
bool ChatLLM::loadModel(const ModelInfo &modelInfo)
@ -223,6 +232,13 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
if (isModelLoaded() && this->modelInfo() == modelInfo)
return true;
// reset status
emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
emit modelLoadingError("");
emit reportFallbackReason("");
emit reportDevice("");
m_pristineLoadedState = false;
QString filePath = modelInfo.dirpath + modelInfo.filename();
QFileInfo fileInfo(filePath);
@ -231,28 +247,25 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
if (alreadyAcquired) {
resetContext();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
m_llModelInfo.model.reset();
} else if (!m_isServer) {
// This is a blocking call that tries to retrieve the model we need from the model store.
// If it succeeds, then we just have to restore state. If the store has never had a model
// returned to it, then the modelInfo.model pointer should be null which will happen on startup
m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
// At this point it is possible that while we were blocked waiting to acquire the model from the
// store, that our state was changed to not be loaded. If this is the case, release the model
// back into the store and quit loading
if (!m_shouldBeLoaded) {
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "no longer need model" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "no longer need model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
m_llModelInfo = LLModelInfo();
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
emit modelLoadingPercentageChanged(0.0f);
return false;
}
@ -260,7 +273,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
// Check if the store just gave us exactly the model we were looking for
if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo && !m_reloadingToChangeVariant) {
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
restoreState();
emit modelLoadingPercentageChanged(1.0f);
@ -274,10 +287,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
} else {
// Release the memory since we have to switch to a different model.
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "deleting model" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "deleting model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
m_llModelInfo.model.reset();
}
}
@ -307,7 +319,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
model->setModelName(modelName);
model->setRequestURL(modelInfo.url());
model->setAPIKey(apiKey);
m_llModelInfo.model = model;
m_llModelInfo.model.reset(model);
} else {
QElapsedTimer modelLoadTimer;
modelLoadTimer.start();
@ -322,9 +334,10 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
buildVariant = "metal";
#endif
QString constructError;
m_llModelInfo.model = nullptr;
m_llModelInfo.model.reset();
try {
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
auto *model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
m_llModelInfo.model.reset(model);
} catch (const LLModel::MissingImplementationError &e) {
modelLoadProps.insert("error", "missing_model_impl");
constructError = e.what();
@ -354,8 +367,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
return m_shouldBeLoaded;
});
emit reportFallbackReason(""); // no fallback yet
auto approxDeviceMemGB = [](const LLModel::GPUDevice *dev) {
float memGB = dev->heapSize / float(1024 * 1024 * 1024);
return std::floor(memGB * 10.f) / 10.f; // truncate to 1 decimal place
@ -406,6 +417,16 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
emit reportDevice(actualDevice);
bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, ngl);
if (!m_shouldBeLoaded) {
m_llModelInfo.model.reset();
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo();
emit modelLoadingPercentageChanged(0.0f);
return false;
}
if (actualDevice == "CPU") {
// we asked llama.cpp to use the CPU
} else if (!success) {
@ -414,6 +435,15 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
emit reportFallbackReason("<br>GPU loading failed (out of VRAM?)");
modelLoadProps.insert("cpu_fallback_reason", "gpu_load_failed");
success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, 0);
if (!m_shouldBeLoaded) {
m_llModelInfo.model.reset();
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo();
emit modelLoadingPercentageChanged(0.0f);
return false;
}
} else if (!m_llModelInfo.model->usingGPUDevice()) {
// ggml_vk_init was not called in llama.cpp
// We might have had to fallback to CPU after load if the model is not possible to accelerate
@ -424,10 +454,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
}
if (!success) {
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
m_llModelInfo.model.reset();
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename()));
modelLoadProps.insert("error", "loadmodel_failed");
@ -437,10 +466,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
case 'G': m_llModelType = LLModelType::GPTJ_; break;
default:
{
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
m_llModelInfo.model.reset();
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not determine model type for %1").arg(modelInfo.filename()));
}
@ -450,13 +478,13 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
}
} else {
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Error loading %1: %2").arg(modelInfo.filename()).arg(constructError));
}
}
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
restoreState();
#if defined(DEBUG)
@ -470,7 +498,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
Network::globalInstance()->trackChatEvent("model_load", modelLoadProps);
} else {
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); // release back into the store
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not find file for model %1").arg(modelInfo.filename()));
}
@ -479,7 +507,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
setModelInfo(modelInfo);
processSystemPrompt();
}
return m_llModelInfo.model;
return bool(m_llModelInfo.model);
}
bool ChatLLM::isModelLoaded() const
@ -699,22 +727,23 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
emit responseChanged(QString::fromStdString(m_response));
}
emit responseStopped(elapsed);
m_pristineLoadedState = false;
return true;
}
void ChatLLM::setShouldBeLoaded(bool b)
{
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model;
qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model.get();
#endif
m_shouldBeLoaded = b; // atomic
emit shouldBeLoadedChanged();
}
void ChatLLM::setShouldTrySwitchContext(bool b)
void ChatLLM::requestTrySwitchContext()
{
m_shouldTrySwitchContext = b; // atomic
emit shouldTrySwitchContextChanged();
m_shouldBeLoaded = true; // atomic
emit trySwitchContextRequested(modelInfo());
}
void ChatLLM::handleShouldBeLoadedChanged()
@ -725,12 +754,6 @@ void ChatLLM::handleShouldBeLoadedChanged()
unloadModel();
}
void ChatLLM::handleShouldTrySwitchContextChanged()
{
if (m_shouldTrySwitchContext)
trySwitchContextOfLoadedModel(modelInfo());
}
void ChatLLM::unloadModel()
{
if (!isModelLoaded() || m_isServer)
@ -745,17 +768,16 @@ void ChatLLM::unloadModel()
saveState();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
if (m_forceUnloadModel) {
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
m_llModelInfo.model.reset();
m_forceUnloadModel = false;
}
LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
m_llModelInfo = LLModelInfo();
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_pristineLoadedState = false;
}
void ChatLLM::reloadModel()
@ -767,7 +789,7 @@ void ChatLLM::reloadModel()
return;
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
const ModelInfo m = modelInfo();
if (m.name().isEmpty())
@ -794,6 +816,7 @@ void ChatLLM::generateName()
m_nameResponse = trimmed;
emit generatedNameChanged(QString::fromStdString(m_nameResponse));
}
m_pristineLoadedState = false;
}
void ChatLLM::handleChatIdChanged(const QString &id)
@ -933,7 +956,10 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
// If we do not deserialize the KV or it is discarded, then we need to restore the state from the
// text only. This will be a costly operation, but the chat has to be restored from the text archive
// alone.
m_restoreStateFromText = !deserializeKV || discardKV;
if (!deserializeKV || discardKV) {
m_restoreStateFromText = true;
m_pristineLoadedState = true;
}
if (!deserializeKV) {
#if defined(DEBUG)
@ -997,14 +1023,14 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
void ChatLLM::saveState()
{
if (!isModelLoaded())
if (!isModelLoaded() || m_pristineLoadedState)
return;
if (m_llModelType == LLModelType::API_) {
m_state.clear();
QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
stream.setVersion(QDataStream::Qt_6_4);
ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model);
ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model.get());
stream << chatAPI->context();
return;
}
@ -1025,7 +1051,7 @@ void ChatLLM::restoreState()
if (m_llModelType == LLModelType::API_) {
QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
stream.setVersion(QDataStream::Qt_6_4);
ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model);
ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model.get());
QList<QString> context;
stream >> context;
chatAPI->setContext(context);
@ -1044,13 +1070,18 @@ void ChatLLM::restoreState()
if (m_llModelInfo.model->stateSize() == m_state.size()) {
m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
m_processedSystemPrompt = true;
m_pristineLoadedState = true;
} else {
qWarning() << "restoring state from text because" << m_llModelInfo.model->stateSize() << "!=" << m_state.size();
m_restoreStateFromText = true;
}
m_state.clear();
m_state.squeeze();
// free local state copy unless unload is pending
if (m_shouldBeLoaded) {
m_state.clear();
m_state.squeeze();
m_pristineLoadedState = false;
}
}
void ChatLLM::processSystemPrompt()
@ -1104,6 +1135,7 @@ void ChatLLM::processSystemPrompt()
#endif
m_processedSystemPrompt = m_stopGenerating == false;
m_pristineLoadedState = false;
}
void ChatLLM::processRestoreStateFromText()
@ -1162,4 +1194,6 @@ void ChatLLM::processRestoreStateFromText()
m_isRecalc = false;
emit recalcChanged();
m_pristineLoadedState = false;
}

@ -5,6 +5,8 @@
#include <QThread>
#include <QFileInfo>
#include <memory>
#include "database.h"
#include "modellist.h"
#include "../gpt4all-backend/llmodel.h"
@ -16,7 +18,7 @@ enum LLModelType {
};
struct LLModelInfo {
LLModel *model = nullptr;
std::unique_ptr<LLModel> model;
QFileInfo fileInfo;
// NOTE: This does not store the model type or name on purpose as this is left for ChatLLM which
// must be able to serialize the information even if it is in the unloaded state
@ -72,6 +74,7 @@ public:
virtual ~ChatLLM();
void destroy();
static void destroyStore();
bool isModelLoaded() const;
void regenerateResponse();
void resetResponse();
@ -81,7 +84,7 @@ public:
bool shouldBeLoaded() const { return m_shouldBeLoaded; }
void setShouldBeLoaded(bool b);
void setShouldTrySwitchContext(bool b);
void requestTrySwitchContext();
void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
void setMarkedForDeletion(bool b) { m_markedForDeletion = b; }
@ -101,7 +104,7 @@ public:
public Q_SLOTS:
bool prompt(const QList<QString> &collectionList, const QString &prompt);
bool loadDefaultModel();
bool trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
bool loadModel(const ModelInfo &modelInfo);
void modelChangeRequested(const ModelInfo &modelInfo);
void unloadModel();
@ -109,7 +112,6 @@ public Q_SLOTS:
void generateName();
void handleChatIdChanged(const QString &id);
void handleShouldBeLoadedChanged();
void handleShouldTrySwitchContextChanged();
void handleThreadStarted();
void handleForceMetalChanged(bool forceMetal);
void handleDeviceChanged();
@ -128,8 +130,8 @@ Q_SIGNALS:
void stateChanged();
void threadStarted();
void shouldBeLoadedChanged();
void shouldTrySwitchContextChanged();
void trySwitchContextOfLoadedModelCompleted(bool);
void trySwitchContextRequested(const ModelInfo &modelInfo);
void trySwitchContextOfLoadedModelCompleted(int value);
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
void reportSpeed(const QString &speed);
void reportDevice(const QString &device);
@ -172,7 +174,6 @@ private:
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded;
std::atomic<bool> m_shouldTrySwitchContext;
std::atomic<bool> m_isRecalc;
std::atomic<bool> m_forceUnloadModel;
std::atomic<bool> m_markedForDeletion;
@ -181,6 +182,10 @@ private:
bool m_reloadingToChangeVariant;
bool m_processedSystemPrompt;
bool m_restoreStateFromText;
// m_pristineLoadedState is set if saveState is unnecessary, either because:
// - an unload was queued during LLModel::restoreState()
// - the chat will be restored from text and hasn't been interacted with yet
bool m_pristineLoadedState = false;
QVector<QPair<QString, QString>> m_stateFromText;
};

@ -122,8 +122,6 @@ Rectangle {
return ModelList.modelInfo(currentChat.modelInfo.id).name;
}
property bool isCurrentlyLoading: false
PopupDialog {
id: errorCompatHardware
anchors.centerIn: parent
@ -339,25 +337,17 @@ Rectangle {
width: window.width >= 750 ? implicitWidth : implicitWidth - (750 - window.width)
enabled: !currentChat.isServer
&& !currentChat.trySwitchContextInProgress
&& !window.isCurrentlyLoading
&& !currentChat.isCurrentlyLoading
model: ModelList.installedModels
valueRole: "id"
textRole: "name"
function changeModel(index) {
window.isCurrentlyLoading = true;
currentChat.stopGenerating()
currentChat.reset();
currentChat.modelInfo = ModelList.modelInfo(comboBox.valueAt(index))
}
Connections {
target: currentChat
function onModelLoadingPercentageChanged() {
window.isCurrentlyLoading = currentChat.modelLoadingPercentage !== 0.0
&& currentChat.modelLoadingPercentage !== 1.0;
}
}
Connections {
target: switchModelDialog
function onAccepted() {
@ -374,7 +364,7 @@ Rectangle {
}
contentItem: Item {
Rectangle {
visible: window.isCurrentlyLoading
visible: currentChat.isCurrentlyLoading
anchors.bottom: parent.bottom
width: modelProgress.visualPosition * parent.width
height: 10
@ -396,13 +386,15 @@ Rectangle {
text: {
if (currentChat.modelLoadingError !== "")
return qsTr("Model loading error...")
if (currentChat.trySwitchContextInProgress)
if (currentChat.trySwitchContextInProgress == 1)
return qsTr("Waiting for model...")
if (currentChat.trySwitchContextInProgress == 2)
return qsTr("Switching context...")
if (currentModelName() === "")
return qsTr("Choose a model...")
if (currentChat.modelLoadingPercentage === 0.0)
return qsTr("Reload \u00B7 ") + currentModelName()
if (window.isCurrentlyLoading)
if (currentChat.isCurrentlyLoading)
return qsTr("Loading \u00B7 ") + currentModelName()
return currentModelName()
}
@ -446,7 +438,7 @@ Rectangle {
MyMiniButton {
id: ejectButton
visible: currentChat.isModelLoaded && !window.isCurrentlyLoading
visible: currentChat.isModelLoaded && !currentChat.isCurrentlyLoading
z: 500
anchors.right: parent.right
anchors.rightMargin: 50
@ -465,7 +457,7 @@ Rectangle {
id: reloadButton
visible: currentChat.modelLoadingError === ""
&& !currentChat.trySwitchContextInProgress
&& !window.isCurrentlyLoading
&& !currentChat.isCurrentlyLoading
&& (currentChat.isModelLoaded || currentModelName() !== "")
z: 500
anchors.right: ejectButton.visible ? ejectButton.left : parent.right
@ -1334,8 +1326,9 @@ Rectangle {
textColor: theme.textColor
visible: !currentChat.isServer
&& !currentChat.isModelLoaded
&& currentChat.modelLoadingError === ""
&& !currentChat.trySwitchContextInProgress
&& !window.isCurrentlyLoading
&& !currentChat.isCurrentlyLoading
&& currentModelName() !== ""
Image {
