typescript!: chatSessions, fixes, tokenStreams (#2045)

Signed-off-by: jacob <jacoobes@sern.dev>
Signed-off-by: limez <limez@protonmail.com>
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: limez <limez@protonmail.com>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
parent 6c8a44f6c4
commit 55f3b056b7

@ -11,37 +11,116 @@ pnpm install gpt4all@latest
```
The original [GPT4All typescript bindings](https://github.com/nomic-ai/gpt4all-ts) are now out of date.
## Contents
* New bindings created by [jacoobes](https://github.com/jacoobes), [limez](https://github.com/iimez) and the [nomic ai community](https://home.nomic.ai), for all to use.
* The nodejs api has made strides to mirror the python api. It is not 100% mirrored, but many pieces of the api resemble its python counterpart.
* Everything should work out of the box.
* See [API Reference](#api-reference)
* See [Examples](#api-example)
* See [Developing](#develop)
* GPT4ALL nodejs bindings created by [jacoobes](https://github.com/jacoobes), [limez](https://github.com/iimez) and the [nomic ai community](https://home.nomic.ai), for all to use.
## Api Example
### Chat Completion
```js
import { createCompletion, loadModel } from '../src/gpt4all.js'
import { LLModel, createCompletion, DEFAULT_DIRECTORY, DEFAULT_LIBRARIES_DIRECTORY, loadModel } from '../src/gpt4all.js'
const model = await loadModel( 'mistral-7b-openorca.gguf2.Q4_0.gguf', { verbose: true, device: 'gpu' });
const model = await loadModel('mistral-7b-openorca.Q4_0.gguf', { verbose: true });
const completion1 = await createCompletion(model, 'What is 1 + 1?', { verbose: true, })
console.log(completion1.message)
const response = await createCompletion(model, [
{ role : 'system', content: 'You are meant to be annoying and unhelpful.' },
{ role : 'user', content: 'What is 1 + 1?' }
]);
const completion2 = await createCompletion(model, 'And if we add two?', { verbose: true })
console.log(completion2.message)
model.dispose()
```
### Embedding
```js
import { createEmbedding, loadModel } from '../src/gpt4all.js'
import { loadModel, createEmbedding } from '../src/gpt4all.js'
const embedder = await loadModel("all-MiniLM-L6-v2-f16.gguf", { verbose: true, type: 'embedding'})
console.log(createEmbedding(embedder, "Maybe Minecraft was the friends we made along the way"));
```
### Chat Sessions
```js
import { loadModel, createCompletion } from "../src/gpt4all.js";
const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", {
verbose: true,
device: "gpu",
});
const model = await loadModel('ggml-all-MiniLM-L6-v2-f16', { verbose: true });
const chat = await model.createChatSession();
await createCompletion(
chat,
"Why are bananas rather blue than bread at night sometimes?",
{
verbose: true,
}
);
await createCompletion(chat, "Are you sure?", { verbose: true, });
const fltArray = createEmbedding(model, "Pain is inevitable, suffering optional");
```
### Streaming responses
```js
import gpt from "../src/gpt4all.js";
const model = await gpt.loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
device: "gpu",
});
process.stdout.write("### Stream:");
const stream = gpt.createCompletionStream(model, "How are you?");
stream.tokens.on("data", (data) => {
process.stdout.write(data);
});
//wait till stream finishes. We cannot continue until this one is done.
await stream.result;
process.stdout.write("\n");
process.stdout.write("### Stream with pipe:");
const stream2 = gpt.createCompletionStream(
model,
"Please say something nice about node streams."
);
stream2.tokens.pipe(process.stdout);
await stream2.result;
process.stdout.write("\n");
console.log("done");
model.dispose();
```
### Async Generators
```js
import gpt from "../src/gpt4all.js";
const model = await gpt.loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
device: "gpu",
});
process.stdout.write("### Generator:");
const gen = gpt.createCompletionGenerator(model, "Redstone in Minecraft is Turing Complete. Let that sink in. (let it in!)");
for await (const chunk of gen) {
process.stdout.write(chunk);
}
process.stdout.write("\n");
model.dispose();
```
## Develop
### Build Instructions
* binding.gyp is compile config
@ -131,21 +210,27 @@ yarn test
* why your model may be spewing bull 💩
* The downloaded model is broken (just reinstall or download from official site)
* That's it so far
* Your model is hanging after a call to generate tokens.
* Is `nPast` set too high? This may cause your model to hang (observed 03/16/2024 on Linux Mint and Ubuntu 22.04)
* Your GPU usage is still high after node.js exits.
* Make sure to call `model.dispose()`!!!
### Roadmap
This package is in active development, and breaking changes may happen until the api stabilizes. Here's what's on the todo list:
This package has been stabilizing over time, and breaking changes may still happen until the api fully stabilizes. Here's the todo list:
* \[ ] Purely offline. Like the gui, which can be run completely offline, the bindings should be as well.
* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help)
* Should include prebuilds to avoid painful node-gyp errors
* \[x] createChatSession ( the python equivalent to create\_chat\_session )
* \[x] generateTokens, the new name for createTokenStream. As of 3.2.0, this is released but not 100% tested. Check spec/generator.mjs!
* \[x] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete
* \[x] prompt models via a threadsafe function in order to have proper non blocking behavior in nodejs
* \[ ] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete
* \[x] generateTokens is the new name for this^
* \[x] proper unit testing (integrate with circle ci)
* \[x] publish to npm under alpha tag `gpt4all@alpha`
* \[x] have more people test on other platforms (mac tester needed)
* \[x] switch to new pluggable backend
* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help)
* Should include prebuilds to avoid painful node-gyp errors
* \[ ] createChatSession ( the python equivalent to create\_chat\_session )
### API Reference
@ -153,144 +238,200 @@ This package is in active development, and breaking changes may happen until the
##### Table of Contents
* [ModelFile](#modelfile)
* [gptj](#gptj)
* [llama](#llama)
* [mpt](#mpt)
* [replit](#replit)
* [type](#type)
* [TokenCallback](#tokencallback)
* [ChatSessionOptions](#chatsessionoptions)
* [systemPrompt](#systemprompt)
* [messages](#messages)
* [initialize](#initialize)
* [Parameters](#parameters)
* [generate](#generate)
* [Parameters](#parameters-1)
* [InferenceModel](#inferencemodel)
* [createChatSession](#createchatsession)
* [Parameters](#parameters-2)
* [generate](#generate-1)
* [Parameters](#parameters-3)
* [dispose](#dispose)
* [EmbeddingModel](#embeddingmodel)
* [dispose](#dispose-1)
* [InferenceResult](#inferenceresult)
* [LLModel](#llmodel)
* [constructor](#constructor)
* [Parameters](#parameters)
* [Parameters](#parameters-4)
* [type](#type-1)
* [name](#name)
* [stateSize](#statesize)
* [threadCount](#threadcount)
* [setThreadCount](#setthreadcount)
* [Parameters](#parameters-1)
* [raw\_prompt](#raw_prompt)
* [Parameters](#parameters-2)
* [Parameters](#parameters-5)
* [infer](#infer)
* [Parameters](#parameters-6)
* [embed](#embed)
* [Parameters](#parameters-3)
* [Parameters](#parameters-7)
* [isModelLoaded](#ismodelloaded)
* [setLibraryPath](#setlibrarypath)
* [Parameters](#parameters-4)
* [Parameters](#parameters-8)
* [getLibraryPath](#getlibrarypath)
* [initGpuByString](#initgpubystring)
* [Parameters](#parameters-5)
* [Parameters](#parameters-9)
* [hasGpuDevice](#hasgpudevice)
* [listGpu](#listgpu)
* [Parameters](#parameters-6)
* [Parameters](#parameters-10)
* [dispose](#dispose-2)
* [GpuDevice](#gpudevice)
* [type](#type-2)
* [LoadModelOptions](#loadmodeloptions)
* [modelPath](#modelpath)
* [librariesPath](#librariespath)
* [modelConfigFile](#modelconfigfile)
* [allowDownload](#allowdownload)
* [verbose](#verbose)
* [device](#device)
* [nCtx](#nctx)
* [ngl](#ngl)
* [loadModel](#loadmodel)
* [Parameters](#parameters-7)
* [Parameters](#parameters-11)
* [InferenceProvider](#inferenceprovider)
* [createCompletion](#createcompletion)
* [Parameters](#parameters-8)
* [Parameters](#parameters-12)
* [createCompletionStream](#createcompletionstream)
* [Parameters](#parameters-13)
* [createCompletionGenerator](#createcompletiongenerator)
* [Parameters](#parameters-14)
* [createEmbedding](#createembedding)
* [Parameters](#parameters-9)
* [Parameters](#parameters-15)
* [CompletionOptions](#completionoptions)
* [verbose](#verbose)
* [systemPromptTemplate](#systemprompttemplate)
* [promptTemplate](#prompttemplate)
* [promptHeader](#promptheader)
* [promptFooter](#promptfooter)
* [PromptMessage](#promptmessage)
* [verbose](#verbose-1)
* [onToken](#ontoken)
* [Message](#message)
* [role](#role)
* [content](#content)
* [prompt\_tokens](#prompt_tokens)
* [completion\_tokens](#completion_tokens)
* [total\_tokens](#total_tokens)
* [n\_past\_tokens](#n_past_tokens)
* [CompletionReturn](#completionreturn)
* [model](#model)
* [usage](#usage)
* [choices](#choices)
* [CompletionChoice](#completionchoice)
* [message](#message)
* [message](#message-1)
* [CompletionStreamReturn](#completionstreamreturn)
* [LLModelPromptContext](#llmodelpromptcontext)
* [logitsSize](#logitssize)
* [tokensSize](#tokenssize)
* [nPast](#npast)
* [nCtx](#nctx)
* [nPredict](#npredict)
* [promptTemplate](#prompttemplate)
* [nCtx](#nctx-1)
* [topK](#topk)
* [topP](#topp)
* [temp](#temp)
* [minP](#minp)
* [temperature](#temperature)
* [nBatch](#nbatch)
* [repeatPenalty](#repeatpenalty)
* [repeatLastN](#repeatlastn)
* [contextErase](#contexterase)
* [generateTokens](#generatetokens)
* [Parameters](#parameters-10)
* [DEFAULT\_DIRECTORY](#default_directory)
* [DEFAULT\_LIBRARIES\_DIRECTORY](#default_libraries_directory)
* [DEFAULT\_MODEL\_CONFIG](#default_model_config)
* [DEFAULT\_PROMPT\_CONTEXT](#default_prompt_context)
* [DEFAULT\_MODEL\_LIST\_URL](#default_model_list_url)
* [downloadModel](#downloadmodel)
* [Parameters](#parameters-11)
* [Parameters](#parameters-16)
* [Examples](#examples)
* [DownloadModelOptions](#downloadmodeloptions)
* [modelPath](#modelpath)
* [verbose](#verbose-1)
* [modelPath](#modelpath-1)
* [verbose](#verbose-2)
* [url](#url)
* [md5sum](#md5sum)
* [DownloadController](#downloadcontroller)
* [cancel](#cancel)
* [promise](#promise)
#### ModelFile
#### type
Full list of models available
DEPRECATED!! These model names are outdated and this type will not be maintained; please use a string literal instead
Model architecture. This argument currently does not have any functionality and is just used as a descriptive identifier for the user.
##### gptj
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
List of GPT-J Models
#### TokenCallback
Type: (`"ggml-gpt4all-j-v1.3-groovy.bin"` | `"ggml-gpt4all-j-v1.2-jazzy.bin"` | `"ggml-gpt4all-j-v1.1-breezy.bin"` | `"ggml-gpt4all-j.bin"`)
Callback for controlling token generation. Return false to stop token generation.
##### llama
Type: function (tokenId: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), token: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String), total: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)): [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean)
List Llama Models
#### ChatSessionOptions
Type: (`"ggml-gpt4all-l13b-snoozy.bin"` | `"ggml-vicuna-7b-1.1-q4_2.bin"` | `"ggml-vicuna-13b-1.1-q4_2.bin"` | `"ggml-wizardLM-7B.q4_2.bin"` | `"ggml-stable-vicuna-13B.q4_2.bin"` | `"ggml-nous-gpt4-vicuna-13b.bin"` | `"ggml-v3-13b-hermes-q5_1.bin"`)
**Extends Partial\<LLModelPromptContext>**
##### mpt
Options for the chat session.
List of MPT Models
##### systemPrompt
Type: (`"ggml-mpt-7b-base.bin"` | `"ggml-mpt-7b-chat.bin"` | `"ggml-mpt-7b-instruct.bin"`)
System prompt to ingest on initialization.
##### replit
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
List of Replit Models
##### messages
Type: `"ggml-replit-code-v1-3b.bin"`
Messages to ingest on initialization.
#### type
Type: [Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[Message](#message)>
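For illustration only, a minimal sketch of seeding a chat session with prior messages; the model file is just the one used in the examples above, and the option shape follows this reference:
```js
import { createCompletion, loadModel } from "gpt4all";

const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf");
const chat = await model.createChatSession({
    // prior turns ingested before the first completion
    messages: [
        { role: "user", content: "What is the capital of France?" },
        { role: "assistant", content: "Paris." },
    ],
});
const res = await createCompletion(chat, "And of Germany?");
console.log(res.choices[0].message);
```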
Model architecture. This argument currently does not have any functionality and is just used as a descriptive identifier for the user.
#### initialize
Type: ModelType
Ingests system prompt and initial messages.
Sets this chat session as the active chat session of the model.
#### TokenCallback
##### Parameters
Callback for controlling token generation
* `options` **[ChatSessionOptions](#chatsessionoptions)** The options for the chat session.
Type: function (tokenId: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), token: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String), total: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)): [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean)
Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)\<void>**&#x20;
#### generate
Prompts the model in chat-session context.
##### Parameters
* `prompt` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input.
* `options` **[CompletionOptions](#completionoptions)?** Prompt context and other options.
* `callback` **[TokenCallback](#tokencallback)?** Token generation callback.
<!---->
* Throws **[Error](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Error)** If the chat session is not the active chat session of the model.
Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[CompletionReturn](#completionreturn)>** The model's response to the prompt.
#### InferenceModel
InferenceModel represents an LLM which can make chat predictions, similar to GPT transformers.
##### createChatSession
Create a chat session with the model.
###### Parameters
* `options` **[ChatSessionOptions](#chatsessionoptions)?** The options for the chat session.
Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)\<ChatSession>** The chat session.
##### generate
Prompts the model with a given input and optional parameters.
###### Parameters
* `prompt` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)**&#x20;
* `options` **[CompletionOptions](#completionoptions)?** Prompt context and other options.
* `callback` **[TokenCallback](#tokencallback)?** Token generation callback.
* `input` The prompt input.
Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[CompletionReturn](#completionreturn)>** The model's response to the prompt.
##### dispose
delete and cleanup the native model
@ -307,6 +448,10 @@ delete and cleanup the native model
Returns **void**&#x20;
#### InferenceResult
Shape of LLModel's inference result.
#### LLModel
LLModel class representing a language model.
@ -326,9 +471,9 @@ Initialize a new LLModel.
##### type
either 'gpt', 'mpt', or 'llama', or undefined
undefined or user supplied
Returns **(ModelType | [undefined](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/undefined))**&#x20;
Returns **([string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String) | [undefined](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/undefined))**&#x20;
##### name
@ -360,7 +505,7 @@ Set the number of threads used for model inference.
Returns **void**&#x20;
##### raw\_prompt
##### infer
Prompt the model with a given input and optional parameters.
This is the raw output from the model.
@ -368,23 +513,20 @@ Use the prompt function exported for a value
###### Parameters
* `q` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input.
* `params` **Partial<[LLModelPromptContext](#llmodelpromptcontext)>** Optional parameters for the prompt context.
* `prompt` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input.
* `promptContext` **Partial<[LLModelPromptContext](#llmodelpromptcontext)>** Optional parameters for the prompt context.
* `callback` **[TokenCallback](#tokencallback)?** optional callback to control token generation.
Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The result of the model prompt.
Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[InferenceResult](#inferenceresult)>** The result of the model prompt.
##### embed
Embed text with the model. Keep in mind that
not all models can embed text (only bert can embed as of 07/16/2023 (mm/dd/yyyy)).
Use the prompt function exported for a value
###### Parameters
* `text` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)**&#x20;
* `q` The prompt input.
* `params` Optional parameters for the prompt context.
* `text` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input.
Returns **[Float32Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Float32Array)** The result of the model prompt.
@ -462,6 +604,62 @@ Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Globa
Options that configure a model's behavior.
##### modelPath
Where to look for model files.
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
##### librariesPath
Where to look for the backend libraries.
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
##### modelConfigFile
The path to the model configuration file, useful for offline usage or custom model configurations.
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
##### allowDownload
Whether to allow downloading the model if it is not present at the specified path.
Type: [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean)
##### verbose
Enable verbose logging.
Type: [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean)
##### device
The processing unit on which the model will run. It can be set to
* "cpu": Model will run on the central processing unit.
* "gpu": Model will run on the best available graphics processing unit, irrespective of its vendor.
* "amd", "nvidia", "intel": Model will run on the best available GPU from the specified vendor.
* "gpu name": Model will run on the GPU that matches the name if it's available.
Note: If a GPU device lacks sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All
instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the
model.
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
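As a hedged illustration (model name reused from the examples above), selecting a vendor-specific device at load time might look like:
```js
import { loadModel } from "gpt4all";

// run on the best available NVIDIA GPU; per the note above, an error is thrown
// if the chosen device does not have enough memory for the model
const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
    device: "nvidia", // or "cpu", "gpu", "amd", "intel", or a specific GPU name
});
```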
##### nCtx
The maximum context window size of this model
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
##### ngl
Number of gpu layers needed
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
#### loadModel
Loads a machine learning model with the specified name. The de facto way to create a model.
@ -474,18 +672,46 @@ By default this will download a model from the official GPT4ALL website, if a mo
Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<([InferenceModel](#inferencemodel) | [EmbeddingModel](#embeddingmodel))>** A promise that resolves to an instance of the loaded LLModel.
#### InferenceProvider
Interface for inference, implemented by InferenceModel and ChatSession.
#### createCompletion
The nodejs equivalent to python binding's chat\_completion
##### Parameters
* `model` **[InferenceModel](#inferencemodel)** The language model object.
* `messages` **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[PromptMessage](#promptmessage)>** The array of messages for the conversation.
* `provider` **[InferenceProvider](#inferenceprovider)** The inference model object or chat session
* `message` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The user input message
* `options` **[CompletionOptions](#completionoptions)** The options for creating the completion.
Returns **[CompletionReturn](#completionreturn)** The completion result.
#### createCompletionStream
Streaming variant of createCompletion, returns a stream of tokens and a promise that resolves to the completion result.
##### Parameters
* `provider` **[InferenceProvider](#inferenceprovider)** The inference model object or chat session
* `message` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The user input message.
* `options` **[CompletionOptions](#completionoptions)** The options for creating the completion.
Returns **[CompletionStreamReturn](#completionstreamreturn)** An object of token stream and the completion result promise.
#### createCompletionGenerator
Creates an async generator of tokens
##### Parameters
* `provider` **[InferenceProvider](#inferenceprovider)** The inference model object or chat session
* `message` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The user input message.
* `options` **[CompletionOptions](#completionoptions)** The options for creating the completion.
Returns **AsyncGenerator<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The stream of generated tokens
#### createEmbedding
The nodejs moral equivalent to python binding's Embed4All().embed()
@ -510,34 +736,15 @@ Indicates if verbose logging is enabled.
Type: [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean)
##### systemPromptTemplate
##### onToken
Template for the system message. Will be put before the conversation with %1 being replaced by all system messages.
Note that if this is not defined, system messages will not be included in the prompt.
Callback for controlling token generation. Return false to stop processing.
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
Type: [TokenCallback](#tokencallback)
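A rough sketch of using `onToken` to stop generation early, assuming `total` is the accumulated response text per the TokenCallback signature:
```js
import { createCompletion, loadModel } from "gpt4all";

const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf");
const res = await createCompletion(model, "Tell me a long story.", {
    // returning false stops token generation once the response passes ~200 characters
    onToken: (tokenId, token, total) => total.length < 200,
});
console.log(res.choices[0].message);
```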
##### promptTemplate
Template for user messages, with %1 being replaced by the message.
Type: [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean)
##### promptHeader
The initial instruction for the model, on top of the prompt
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
##### promptFooter
The last instruction for the model, appended to the end of the prompt.
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
#### Message
#### PromptMessage
A message in the conversation, identical to OpenAI's chat message.
A message in the conversation.
##### role
@ -553,7 +760,7 @@ Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Globa
#### prompt\_tokens
The number of tokens used in the prompt.
The number of tokens used in the prompt. Currently not available and always 0.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
@ -565,13 +772,19 @@ Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Globa
#### total\_tokens
The total number of tokens used.
The total number of tokens used. Currently not available and always 0.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
#### n\_past\_tokens
Number of tokens used in the conversation.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
#### CompletionReturn
The result of the completion, similar to OpenAI's format.
The result of a completion.
##### model
@ -583,23 +796,17 @@ Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Globa
Token usage report.
Type: {prompt\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), completion\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), total\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)}
##### choices
Type: {prompt\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), completion\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), total\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), n\_past\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)}
The generated completions.
Type: [Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[CompletionChoice](#completionchoice)>
#### CompletionChoice
##### message
A completion choice, similar to OpenAI's format.
The generated completion.
##### message
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
Response message
#### CompletionStreamReturn
Type: [PromptMessage](#promptmessage)
The result of a streamed completion, containing a stream of tokens and a promise that resolves to the completion result.
#### LLModelPromptContext
@ -620,18 +827,29 @@ Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Globa
##### nPast
The number of tokens in the past conversation.
This controls how far back the model looks when generating completions.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
##### nCtx
##### nPredict
The number of tokens possible in the context window.
The maximum number of tokens to predict.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
##### nPredict
##### promptTemplate
Template for user / assistant message pairs.
%1 is required and will be replaced by the user input.
%2 is optional and will be replaced by the assistant response.
The number of tokens to predict.
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
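For illustration, a sketch of overriding the template; the exact markup is model-dependent, and the Alpaca-style string below is only an assumption, passed through completion options as prompt context:
```js
import { createCompletion, loadModel } from "gpt4all";

const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf");
const res = await createCompletion(model, "What is 1 + 1?", {
    // %1 is replaced by the user input; %2 (the assistant response) is optional
    promptTemplate: "### User:\n%1\n\n### Response:\n",
});
```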
##### nCtx
The context window size. Do not use, it has no effect. See loadModel options.
THIS IS DEPRECATED!!!
Use loadModel's nCtx option instead.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
@ -654,12 +872,16 @@ above a threshold P. This method, also known as nucleus sampling, finds a balanc
and quality by considering both token probabilities and the number of tokens available for sampling.
When using a higher value for top-P (e.g., 0.95), the generated text becomes more diverse.
On the other hand, a lower value (e.g., 0.1) produces more focused and conservative text.
The default value is 0.4, which aims to strike a middle ground between focus and diversity, but
for more creative tasks a higher top-P value is beneficial; about 0.5-0.9 is a good range.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
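A minimal sketch of leaning on the creative end of the range suggested above (temperature is the other sampling field documented in this context):
```js
import { createCompletion, loadModel } from "gpt4all";

const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf");
const res = await createCompletion(model, "Write a short poem about redstone.", {
    topP: 0.9,        // wider nucleus, more diverse wording
    temperature: 0.9, // higher temperature, more creative output
});
console.log(res.choices[0].message);
```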
##### temp
##### minP
The minimum probability of a token to be considered.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
##### temperature
The temperature to adjust the model's output distribution.
Temperature is like a knob that adjusts how creative or focused the output becomes. Higher temperatures
@ -704,19 +926,6 @@ The percentage of context to erase if the context window is exceeded.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
#### generateTokens
Creates an async generator of tokens
##### Parameters
* `llmodel` **[InferenceModel](#inferencemodel)** The language model object.
* `messages` **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[PromptMessage](#promptmessage)>** The array of messages for the conversation.
* `options` **[CompletionOptions](#completionoptions)** The options for creating the completion.
* `callback` **[TokenCallback](#tokencallback)** optional callback to control token generation.
Returns **AsyncGenerator<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The stream of generated tokens
#### DEFAULT\_DIRECTORY
From python api:
@ -759,7 +968,7 @@ By default this downloads without waiting. use the controller returned to alter
##### Parameters
* `modelName` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The model to be downloaded.
* `options` **DownloadOptions** to pass into the downloader. Default is { location: (cwd), verbose: false }.
* `options` **[DownloadModelOptions](#downloadmodeloptions)** to pass into the downloader. Default is { location: (cwd), verbose: false }.
##### Examples

@ -0,0 +1,4 @@
---
Language: Cpp
BasedOnStyle: Microsoft
ColumnLimit: 120

@ -10,45 +10,170 @@ npm install gpt4all@latest
pnpm install gpt4all@latest
```
The original [GPT4All typescript bindings](https://github.com/nomic-ai/gpt4all-ts) are now out of date.
* New bindings created by [jacoobes](https://github.com/jacoobes), [limez](https://github.com/iimez) and the [nomic ai community](https://home.nomic.ai), for all to use.
* The nodejs api has made strides to mirror the python api. It is not 100% mirrored, but many pieces of the api resemble its python counterpart.
* Everything should work out of the box.
## Breaking changes in version 4!!
* See [Transition](#changes)
## Contents
* See [API Reference](#api-reference)
* See [Examples](#api-example)
* See [Developing](#develop)
* GPT4ALL nodejs bindings created by [jacoobes](https://github.com/jacoobes), [limez](https://github.com/iimez) and the [nomic ai community](https://home.nomic.ai), for all to use.
* [spare change](https://github.com/sponsors/jacoobes) for a college student? 🤑
## Api Examples
### Chat Completion
Use a chat session to keep context between completions. This is useful for efficient back and forth conversations.
```js
import { createCompletion, loadModel } from '../src/gpt4all.js'
import { createCompletion, loadModel } from "../src/gpt4all.js";
const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", {
verbose: true, // logs loaded model configuration
device: "gpu", // defaults to 'cpu'
nCtx: 2048, // the session's maximum context window size
});
// initialize a chat session on the model. a model instance can have only one chat session at a time.
const chat = await model.createChatSession({
// any completion options set here will be used as default for all completions in this chat session
temperature: 0.8,
// a custom systemPrompt can be set here. note that the template depends on the model.
// if unset, the systemPrompt that comes with the model will be used.
systemPrompt: "### System:\nYou are an advanced mathematician.\n\n",
});
// create a completion using a string as input
const res1 = await createCompletion(chat, "What is 1 + 1?");
console.debug(res1.choices[0].message);
// multiple messages can be input to the conversation at once.
// note that if the last message is not of role 'user', an empty message will be returned.
await createCompletion(chat, [
{
role: "user",
content: "What is 2 + 2?",
},
{
role: "assistant",
content: "It's 5.",
},
]);
const res3 = await createCompletion(chat, "Could you recalculate that?");
console.debug(res3.choices[0].message);
model.dispose();
```
const model = await loadModel('mistral-7b-openorca.Q4_0.gguf', { verbose: true });
### Stateless usage
You can use the model without a chat session. This is useful for one-off completions.
const response = await createCompletion(model, [
{ role : 'system', content: 'You are meant to be annoying and unhelpful.' },
{ role : 'user', content: 'What is 1 + 1?' }
```js
import { createCompletion, loadModel } from "../src/gpt4all.js";
const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf");
// createCompletion methods can also be used on the model directly.
// context is not maintained between completions.
const res1 = await createCompletion(model, "What is 1 + 1?");
console.debug(res1.choices[0].message);
// a whole conversation can be input as well.
// note that if the last message is not of role 'user', an error will be thrown.
const res2 = await createCompletion(model, [
{
role: "user",
content: "What is 2 + 2?",
},
{
role: "assistant",
content: "It's 5.",
},
{
role: "user",
content: "Could you recalculate that?",
},
]);
console.debug(res2.choices[0].message);
```
### Embedding
```js
import { createEmbedding, loadModel } from '../src/gpt4all.js'
import { loadModel, createEmbedding } from '../src/gpt4all.js'
const embedder = await loadModel("nomic-embed-text-v1.5.f16.gguf", { verbose: true, type: 'embedding'})
console.log(createEmbedding(embedder, "Maybe Minecraft was the friends we made along the way"));
```
### Streaming responses
```js
import { loadModel, createCompletionStream } from "../src/gpt4all.js";
const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
device: "gpu",
});
process.stdout.write("Output: ");
const stream = createCompletionStream(model, "How are you?");
stream.tokens.on("data", (data) => {
process.stdout.write(data);
});
//wait till stream finishes. We cannot continue until this one is done.
await stream.result;
process.stdout.write("\n");
model.dispose();
```
### Async Generators
```js
import { loadModel, createCompletionGenerator } from "../src/gpt4all.js";
const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf");
const model = await loadModel('ggml-all-MiniLM-L6-v2-f16', { verbose: true });
process.stdout.write("Output: ");
const gen = createCompletionGenerator(
model,
"Redstone in Minecraft is Turing Complete. Let that sink in. (let it in!)"
);
for await (const chunk of gen) {
process.stdout.write(chunk);
}
const fltArray = createEmbedding(model, "Pain is inevitable, suffering optional");
process.stdout.write("\n");
model.dispose();
```
### Offline usage
Do this before going offline:
```sh
curl -L https://gpt4all.io/models/models3.json -o ./models3.json
```
```js
import { createCompletion, loadModel } from 'gpt4all'
// make sure you downloaded the models before going offline!
const model = await loadModel('mistral-7b-openorca.gguf2.Q4_0.gguf', {
verbose: true,
device: 'gpu',
modelConfigFile: "./models3.json"
});
await createCompletion(model, 'What is 1 + 1?', { verbose: true })
model.dispose();
```
## Develop
### Build Instructions
* binding.gyp is compile config
* `binding.gyp` is compile config
* Tested on Ubuntu. Everything seems to work fine
* Tested on Windows. Everything works fine.
* Sparse testing on mac os.
* MinGW works as well to build the gpt4all-backend. **HOWEVER**, this package works only with MSVC-built dlls.
* The MinGW script works to build the gpt4all-backend. We left it there just in case. **HOWEVER**, this package works only with MSVC-built dlls.
### Requirements
@ -76,23 +201,18 @@ cd gpt4all-bindings/typescript
* To Build and Rebuild:
```sh
yarn
node scripts/prebuild.js
```
* The llama.cpp git submodule for gpt4all may be absent. If this is the case, make sure to run the following in the llama.cpp parent directory:
```sh
git submodule update --init --depth 1 --recursive
git submodule update --init --recursive
```
```sh
yarn build:backend
```
This will build platform-dependent dynamic libraries, which will be located in runtimes/(platform)/native. The only current way to use them is to put them in the current working directory of your application. That is, **WHEREVER YOU RUN YOUR NODE APPLICATION**.
* llama-xxxx.dll is required.
* Depending on the model you are using, you'll need to select the proper model loader.
* For example, if you are running a Mosaic MPT model, you will need to select the mpt-(buildvariant).(dynamiclibrary)
This will build platform-dependent dynamic libraries, which will be located in runtimes/(platform)/native
### Test
@ -130,17 +250,20 @@ yarn test
* why your model may be spewing bull 💩
* The downloaded model is broken (just reinstall or download from official site)
* That's it so far
* Your model is hanging after a call to generate tokens.
* Is `nPast` set too high? This may cause your model to hang (observed 03/16/2024 on Linux Mint and Ubuntu 22.04)
* Your GPU usage is still high after node.js exits.
* Make sure to call `model.dispose()`!!!
### Roadmap
This package is in active development, and breaking changes may happen until the api stabilizes. Here's what's on the todo list:
This package has been stabilizing over time, and breaking changes may still happen until the api fully stabilizes. Here's the todo list:
* \[ ] Purely offline. Like the gui, which can be run completely offline, the bindings should be as well.
* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help)
* Should include prebuilds to avoid painful node-gyp errors
* \[ ] createChatSession ( the python equivalent to create\_chat\_session )
* \[x] generateTokens, the new name for createTokenStream. As of 3.2.0, this is released but not 100% tested. Check spec/generator.mjs!
* \[x] createChatSession ( the python equivalent to create\_chat\_session )
* \[x] generateTokens, the new name for createTokenStream. As of 3.2.0, this is released but not 100% tested. Check spec/generator.mjs!
* \[x] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete
* \[x] prompt models via a threadsafe function in order to have proper non blocking behavior in nodejs
* \[x] generateTokens is the new name for this^
@ -149,5 +272,13 @@ This package is in active development, and breaking changes may happen until the
* \[x] have more people test on other platforms (mac tester needed)
* \[x] switch to new pluggable backend
## Changes
This repository serves as the new bindings for nodejs users.
- If you were a user of [these bindings](https://github.com/nomic-ai/gpt4all-ts), they are outdated.
- Version 4 includes the following breaking changes:
* `createEmbedding` & `EmbeddingModel.embed()` return an object, `EmbeddingResult`, instead of a Float32Array (see the sketch after this list).
* Removed deprecated types `ModelType` and `ModelFile`
* Removed deprecated initiation of model by string path only
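A rough before/after sketch of the `createEmbedding` change; the `embeddings` and `n_prompt_tokens` field names are read off the native wrapper in this diff and are assumptions about the exact JS-level shape:
```js
import { createEmbedding, loadModel } from "gpt4all";

const embedder = await loadModel("nomic-embed-text-v1.5.f16.gguf", { type: "embedding" });

// v3: createEmbedding returned a Float32Array directly
// v4: it returns an EmbeddingResult object instead
const result = createEmbedding(embedder, "Maybe Minecraft was the friends we made along the way");
console.log(result.embeddings);      // the embedding vector(s)
console.log(result.n_prompt_tokens); // number of prompt tokens used
```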
### API Reference

@ -6,12 +6,12 @@
"<!@(node -p \"require('node-addon-api').include\")",
"gpt4all-backend",
],
"sources": [
"sources": [
# PREVIOUS VERSION: had to require the sources, but with the newest changes we do not need to
#"../../gpt4all-backend/llama.cpp/examples/common.cpp",
#"../../gpt4all-backend/llama.cpp/ggml.c",
#"../../gpt4all-backend/llama.cpp/llama.cpp",
# "../../gpt4all-backend/utils.cpp",
# "../../gpt4all-backend/utils.cpp",
"gpt4all-backend/llmodel_c.cpp",
"gpt4all-backend/llmodel.cpp",
"prompt.cc",
@ -40,7 +40,7 @@
"AdditionalOptions": [
"/std:c++20",
"/EHsc",
],
],
},
},
}],

@ -6,12 +6,12 @@
"<!@(node -p \"require('node-addon-api').include\")",
"../../gpt4all-backend",
],
"sources": [
"sources": [
# PREVIOUS VERSION: had to require the sources, but with the newest changes we do not need to
#"../../gpt4all-backend/llama.cpp/examples/common.cpp",
#"../../gpt4all-backend/llama.cpp/ggml.c",
#"../../gpt4all-backend/llama.cpp/llama.cpp",
# "../../gpt4all-backend/utils.cpp",
# "../../gpt4all-backend/utils.cpp",
"../../gpt4all-backend/llmodel_c.cpp",
"../../gpt4all-backend/llmodel.cpp",
"prompt.cc",
@ -40,7 +40,7 @@
"AdditionalOptions": [
"/std:c++20",
"/EHsc",
],
],
},
},
}],

@ -1,175 +1,171 @@
#include "index.h"
#include "napi.h"
Napi::Function NodeModelWrapper::GetClass(Napi::Env env) {
Napi::Function self = DefineClass(env, "LLModel", {
InstanceMethod("type", &NodeModelWrapper::GetType),
InstanceMethod("isModelLoaded", &NodeModelWrapper::IsModelLoaded),
InstanceMethod("name", &NodeModelWrapper::GetName),
InstanceMethod("stateSize", &NodeModelWrapper::StateSize),
InstanceMethod("raw_prompt", &NodeModelWrapper::Prompt),
InstanceMethod("setThreadCount", &NodeModelWrapper::SetThreadCount),
InstanceMethod("embed", &NodeModelWrapper::GenerateEmbedding),
InstanceMethod("threadCount", &NodeModelWrapper::ThreadCount),
InstanceMethod("getLibraryPath", &NodeModelWrapper::GetLibraryPath),
InstanceMethod("initGpuByString", &NodeModelWrapper::InitGpuByString),
InstanceMethod("hasGpuDevice", &NodeModelWrapper::HasGpuDevice),
InstanceMethod("listGpu", &NodeModelWrapper::GetGpuDevices),
InstanceMethod("memoryNeeded", &NodeModelWrapper::GetRequiredMemory),
InstanceMethod("dispose", &NodeModelWrapper::Dispose)
});
Napi::Function NodeModelWrapper::GetClass(Napi::Env env)
{
Napi::Function self = DefineClass(env, "LLModel",
{InstanceMethod("type", &NodeModelWrapper::GetType),
InstanceMethod("isModelLoaded", &NodeModelWrapper::IsModelLoaded),
InstanceMethod("name", &NodeModelWrapper::GetName),
InstanceMethod("stateSize", &NodeModelWrapper::StateSize),
InstanceMethod("infer", &NodeModelWrapper::Infer),
InstanceMethod("setThreadCount", &NodeModelWrapper::SetThreadCount),
InstanceMethod("embed", &NodeModelWrapper::GenerateEmbedding),
InstanceMethod("threadCount", &NodeModelWrapper::ThreadCount),
InstanceMethod("getLibraryPath", &NodeModelWrapper::GetLibraryPath),
InstanceMethod("initGpuByString", &NodeModelWrapper::InitGpuByString),
InstanceMethod("hasGpuDevice", &NodeModelWrapper::HasGpuDevice),
InstanceMethod("listGpu", &NodeModelWrapper::GetGpuDevices),
InstanceMethod("memoryNeeded", &NodeModelWrapper::GetRequiredMemory),
InstanceMethod("dispose", &NodeModelWrapper::Dispose)});
// Keep a static reference to the constructor
//
Napi::FunctionReference* constructor = new Napi::FunctionReference();
Napi::FunctionReference *constructor = new Napi::FunctionReference();
*constructor = Napi::Persistent(self);
env.SetInstanceData(constructor);
return self;
}
Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo &info)
{
auto env = info.Env();
return Napi::Number::New(env, static_cast<uint32_t>(llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers) ));
return Napi::Number::New(
env, static_cast<uint32_t>(llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers)));
}
Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo& info)
{
Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo &info)
{
auto env = info.Env();
int num_devices = 0;
auto mem_size = llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers);
llmodel_gpu_device* all_devices = llmodel_available_gpu_devices(GetInference(), mem_size, &num_devices);
if(all_devices == nullptr) {
Napi::Error::New(
env,
"Unable to retrieve list of all GPU devices"
).ThrowAsJavaScriptException();
llmodel_gpu_device *all_devices = llmodel_available_gpu_devices(GetInference(), mem_size, &num_devices);
if (all_devices == nullptr)
{
Napi::Error::New(env, "Unable to retrieve list of all GPU devices").ThrowAsJavaScriptException();
return env.Undefined();
}
auto js_array = Napi::Array::New(env, num_devices);
for(int i = 0; i < num_devices; ++i) {
auto gpu_device = all_devices[i];
/*
*
* struct llmodel_gpu_device {
int index = 0;
int type = 0; // same as VkPhysicalDeviceType
size_t heapSize = 0;
const char * name;
const char * vendor;
};
*
*/
Napi::Object js_gpu_device = Napi::Object::New(env);
for (int i = 0; i < num_devices; ++i)
{
auto gpu_device = all_devices[i];
/*
*
* struct llmodel_gpu_device {
int index = 0;
int type = 0; // same as VkPhysicalDeviceType
size_t heapSize = 0;
const char * name;
const char * vendor;
};
*
*/
Napi::Object js_gpu_device = Napi::Object::New(env);
js_gpu_device["index"] = uint32_t(gpu_device.index);
js_gpu_device["type"] = uint32_t(gpu_device.type);
js_gpu_device["heapSize"] = static_cast<uint32_t>( gpu_device.heapSize );
js_gpu_device["name"]= gpu_device.name;
js_gpu_device["heapSize"] = static_cast<uint32_t>(gpu_device.heapSize);
js_gpu_device["name"] = gpu_device.name;
js_gpu_device["vendor"] = gpu_device.vendor;
js_array[i] = js_gpu_device;
}
return js_array;
}
}
Napi::Value NodeModelWrapper::GetType(const Napi::CallbackInfo& info)
{
if(type.empty()) {
Napi::Value NodeModelWrapper::GetType(const Napi::CallbackInfo &info)
{
if (type.empty())
{
return info.Env().Undefined();
}
}
return Napi::String::New(info.Env(), type);
}
}
Napi::Value NodeModelWrapper::InitGpuByString(const Napi::CallbackInfo& info)
{
Napi::Value NodeModelWrapper::InitGpuByString(const Napi::CallbackInfo &info)
{
auto env = info.Env();
size_t memory_required = static_cast<size_t>(info[0].As<Napi::Number>().Uint32Value());
std::string gpu_device_identifier = info[1].As<Napi::String>();
std::string gpu_device_identifier = info[1].As<Napi::String>();
size_t converted_value;
if(memory_required <= std::numeric_limits<size_t>::max()) {
if (memory_required <= std::numeric_limits<size_t>::max())
{
converted_value = static_cast<size_t>(memory_required);
} else {
Napi::Error::New(
env,
"invalid number for memory size. Exceeded bounds for memory."
).ThrowAsJavaScriptException();
}
else
{
Napi::Error::New(env, "invalid number for memory size. Exceeded bounds for memory.")
.ThrowAsJavaScriptException();
return env.Undefined();
}
auto result = llmodel_gpu_init_gpu_device_by_string(GetInference(), converted_value, gpu_device_identifier.c_str());
return Napi::Boolean::New(env, result);
}
Napi::Value NodeModelWrapper::HasGpuDevice(const Napi::CallbackInfo& info)
{
}
Napi::Value NodeModelWrapper::HasGpuDevice(const Napi::CallbackInfo &info)
{
return Napi::Boolean::New(info.Env(), llmodel_has_gpu_device(GetInference()));
}
}
NodeModelWrapper::NodeModelWrapper(const Napi::CallbackInfo& info) : Napi::ObjectWrap<NodeModelWrapper>(info)
{
NodeModelWrapper::NodeModelWrapper(const Napi::CallbackInfo &info) : Napi::ObjectWrap<NodeModelWrapper>(info)
{
auto env = info.Env();
fs::path model_path;
std::string full_weight_path,
library_path = ".",
model_name,
device;
if(info[0].IsString()) {
model_path = info[0].As<Napi::String>().Utf8Value();
full_weight_path = model_path.string();
std::cout << "DEPRECATION: constructor accepts object now. Check docs for more.\n";
} else {
auto config_object = info[0].As<Napi::Object>();
model_name = config_object.Get("model_name").As<Napi::String>();
model_path = config_object.Get("model_path").As<Napi::String>().Utf8Value();
if(config_object.Has("model_type")) {
type = config_object.Get("model_type").As<Napi::String>();
}
full_weight_path = (model_path / fs::path(model_name)).string();
if(config_object.Has("library_path")) {
library_path = config_object.Get("library_path").As<Napi::String>();
} else {
library_path = ".";
}
device = config_object.Get("device").As<Napi::String>();
auto config_object = info[0].As<Napi::Object>();
nCtx = config_object.Get("nCtx").As<Napi::Number>().Int32Value();
nGpuLayers = config_object.Get("ngl").As<Napi::Number>().Int32Value();
}
llmodel_set_implementation_search_path(library_path.c_str());
const char* e;
// sets the directory where models (gguf files) are to be searched
llmodel_set_implementation_search_path(
config_object.Has("library_path") ? config_object.Get("library_path").As<Napi::String>().Utf8Value().c_str()
: ".");
std::string model_name = config_object.Get("model_name").As<Napi::String>();
fs::path model_path = config_object.Get("model_path").As<Napi::String>().Utf8Value();
std::string full_weight_path = (model_path / fs::path(model_name)).string();
name = model_name.empty() ? model_path.filename().string() : model_name;
full_model_path = full_weight_path;
nCtx = config_object.Get("nCtx").As<Napi::Number>().Int32Value();
nGpuLayers = config_object.Get("ngl").As<Napi::Number>().Int32Value();
const char *e;
inference_ = llmodel_model_create2(full_weight_path.c_str(), "auto", &e);
if(!inference_) {
Napi::Error::New(env, e).ThrowAsJavaScriptException();
return;
if (!inference_)
{
Napi::Error::New(env, e).ThrowAsJavaScriptException();
return;
}
if(GetInference() == nullptr) {
std::cerr << "Tried searching libraries in \"" << library_path << "\"" << std::endl;
std::cerr << "Tried searching for model weight in \"" << full_weight_path << "\"" << std::endl;
std::cerr << "Do you have runtime libraries installed?" << std::endl;
Napi::Error::New(env, "Had an issue creating llmodel object, inference is null").ThrowAsJavaScriptException();
return;
if (GetInference() == nullptr)
{
std::cerr << "Tried searching libraries in \"" << llmodel_get_implementation_search_path() << "\"" << std::endl;
std::cerr << "Tried searching for model weight in \"" << full_weight_path << "\"" << std::endl;
std::cerr << "Do you have runtime libraries installed?" << std::endl;
Napi::Error::New(env, "Had an issue creating llmodel object, inference is null").ThrowAsJavaScriptException();
return;
}
if(device != "cpu") {
size_t mem = llmodel_required_mem(GetInference(), full_weight_path.c_str(),nCtx, nGpuLayers);
std::string device = config_object.Get("device").As<Napi::String>();
if (device != "cpu")
{
size_t mem = llmodel_required_mem(GetInference(), full_weight_path.c_str(), nCtx, nGpuLayers);
auto success = llmodel_gpu_init_gpu_device_by_string(GetInference(), mem, device.c_str());
if(!success) {
//https://github.com/nomic-ai/gpt4all/blob/3acbef14b7c2436fe033cae9036e695d77461a16/gpt4all-bindings/python/gpt4all/pyllmodel.py#L215
//Haven't implemented this but it is still open to contribution
if (!success)
{
// https://github.com/nomic-ai/gpt4all/blob/3acbef14b7c2436fe033cae9036e695d77461a16/gpt4all-bindings/python/gpt4all/pyllmodel.py#L215
// Haven't implemented this but it is still open to contribution
std::cout << "WARNING: Failed to init GPU\n";
}
}
auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), nCtx, nGpuLayers);
if(!success) {
Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException();
if (!success)
{
Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException();
return;
}
name = model_name.empty() ? model_path.filename().string() : model_name;
full_model_path = full_weight_path;
};
// optional
if (config_object.Has("model_type"))
{
type = config_object.Get("model_type").As<Napi::String>();
}
};
// NodeModelWrapper::~NodeModelWrapper() {
// if(GetInference() != nullptr) {
@ -182,177 +178,275 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
// if(inference_ != nullptr) {
// std::cout << "Debug: deleting model\n";
//
// }
// }
// }
Napi::Value NodeModelWrapper::IsModelLoaded(const Napi::CallbackInfo& info) {
Napi::Value NodeModelWrapper::IsModelLoaded(const Napi::CallbackInfo &info)
{
return Napi::Boolean::New(info.Env(), llmodel_isModelLoaded(GetInference()));
}
}
Napi::Value NodeModelWrapper::StateSize(const Napi::CallbackInfo& info) {
Napi::Value NodeModelWrapper::StateSize(const Napi::CallbackInfo &info)
{
// Implement the binding for the stateSize method
return Napi::Number::New(info.Env(), static_cast<int64_t>(llmodel_get_state_size(GetInference())));
}
Napi::Value NodeModelWrapper::GenerateEmbedding(const Napi::CallbackInfo& info) {
auto env = info.Env();
std::string text = info[0].As<Napi::String>().Utf8Value();
size_t embedding_size = 0;
float* arr = llmodel_embedding(GetInference(), text.c_str(), &embedding_size);
if(arr == nullptr) {
Napi::Error::New(
env,
"Cannot embed. native embedder returned 'nullptr'"
).ThrowAsJavaScriptException();
return env.Undefined();
}
Napi::Array ChunkedFloatPtr(float *embedding_ptr, int embedding_size, int text_len, Napi::Env const &env)
{
auto n_embd = embedding_size / text_len;
// std::cout << "Embedding size: " << embedding_size << std::endl;
// std::cout << "Text length: " << text_len << std::endl;
// std::cout << "Chunk size (n_embd): " << n_embd << std::endl;
Napi::Array result = Napi::Array::New(env, text_len);
auto count = 0;
for (int i = 0; i < embedding_size; i += n_embd)
{
int end = std::min(i + n_embd, embedding_size);
// possible bounds error?
// Constructs a container with as many elements as the range [first,last), with each element emplace-constructed
// from its corresponding element in that range, in the same order.
std::vector<float> chunk(embedding_ptr + i, embedding_ptr + end);
Napi::Float32Array fltarr = Napi::Float32Array::New(env, chunk.size());
// I know there's a way to emplace the raw float ptr into a Napi::Float32Array but idk how and
// im too scared to cause memory issues
// this is goodenough
for (int j = 0; j < chunk.size(); j++)
{
fltarr.Set(j, chunk[j]);
}
result.Set(count++, fltarr);
}
return result;
}
Napi::Value NodeModelWrapper::GenerateEmbedding(const Napi::CallbackInfo &info)
{
auto env = info.Env();
if(embedding_size == 0 && text.size() != 0 ) {
std::cout << "Warning: embedding length 0 but input text length > 0" << std::endl;
auto prefix = info[1];
auto dimensionality = info[2].As<Napi::Number>().Int32Value();
auto do_mean = info[3].As<Napi::Boolean>().Value();
auto atlas = info[4].As<Napi::Boolean>().Value();
size_t embedding_size;
size_t token_count = 0;
// This procedure can maybe be optimized but its whatever, i have too many intermediary structures
std::vector<std::string> text_arr;
bool is_single_text = false;
if (info[0].IsString())
{
is_single_text = true;
text_arr.push_back(info[0].As<Napi::String>().Utf8Value());
}
else
{
auto jsarr = info[0].As<Napi::Array>();
size_t len = jsarr.Length();
text_arr.reserve(len);
for (size_t i = 0; i < len; ++i)
{
std::string str = jsarr.Get(i).As<Napi::String>().Utf8Value();
text_arr.push_back(str);
}
}
Napi::Float32Array js_array = Napi::Float32Array::New(env, embedding_size);
for (size_t i = 0; i < embedding_size; ++i) {
float element = *(arr + i);
js_array[i] = element;
std::vector<const char *> str_ptrs;
str_ptrs.reserve(text_arr.size() + 1);
for (size_t i = 0; i < text_arr.size(); ++i)
str_ptrs.push_back(text_arr[i].c_str());
str_ptrs.push_back(nullptr);
const char *_err = nullptr;
float *embeds = llmodel_embed(GetInference(), str_ptrs.data(), &embedding_size,
prefix.IsUndefined() ? nullptr : prefix.As<Napi::String>().Utf8Value().c_str(),
dimensionality, &token_count, do_mean, atlas, &_err);
if (!embeds)
{
// i dont wanna deal with c strings lol
std::string err(_err);
Napi::Error::New(env, err == "(unknown error)" ? "Unknown error: sorry bud" : err).ThrowAsJavaScriptException();
return env.Undefined();
}
auto embedmat = ChunkedFloatPtr(embeds, embedding_size, text_arr.size(), env);
llmodel_free_embedding(arr);
llmodel_free_embedding(embeds);
auto res = Napi::Object::New(env);
res.Set("n_prompt_tokens", token_count);
if(is_single_text) {
res.Set("embeddings", embedmat.Get(static_cast<uint32_t>(0)));
} else {
res.Set("embeddings", embedmat);
}
return js_array;
}
return res;
}
/**
* Generate a response using the model.
* @param model A pointer to the llmodel_model instance.
* @param prompt A string representing the input prompt.
* @param prompt_callback A callback function for handling the processing of prompt.
* @param response_callback A callback function for handling the generated response.
* @param recalculate_callback A callback function for handling recalculation requests.
* @param ctx A pointer to the llmodel_prompt_context structure.
* @param options Inference options.
*/
Napi::Value NodeModelWrapper::Prompt(const Napi::CallbackInfo& info) {
Napi::Value NodeModelWrapper::Infer(const Napi::CallbackInfo &info)
{
auto env = info.Env();
std::string question;
if(info[0].IsString()) {
question = info[0].As<Napi::String>().Utf8Value();
} else {
std::string prompt;
if (info[0].IsString())
{
prompt = info[0].As<Napi::String>().Utf8Value();
}
else
{
Napi::Error::New(info.Env(), "invalid string argument").ThrowAsJavaScriptException();
return info.Env().Undefined();
}
//defaults copied from python bindings
llmodel_prompt_context promptContext = {
.logits = nullptr,
.tokens = nullptr,
.n_past = 0,
.n_ctx = 1024,
.n_predict = 128,
.top_k = 40,
.top_p = 0.9f,
.min_p = 0.0f,
.temp = 0.72f,
.n_batch = 8,
.repeat_penalty = 1.0f,
.repeat_last_n = 10,
.context_erase = 0.5
};
PromptWorkerConfig promptWorkerConfig;
if(info[1].IsObject())
{
auto inputObject = info[1].As<Napi::Object>();
// Extract and assign the properties
if (inputObject.Has("logits") || inputObject.Has("tokens")) {
Napi::Error::New(info.Env(), "Invalid input: 'logits' or 'tokens' properties are not allowed").ThrowAsJavaScriptException();
return info.Env().Undefined();
}
// Assign the remaining properties
if(inputObject.Has("n_past"))
promptContext.n_past = inputObject.Get("n_past").As<Napi::Number>().Int32Value();
if(inputObject.Has("n_ctx"))
promptContext.n_ctx = inputObject.Get("n_ctx").As<Napi::Number>().Int32Value();
if(inputObject.Has("n_predict"))
promptContext.n_predict = inputObject.Get("n_predict").As<Napi::Number>().Int32Value();
if(inputObject.Has("top_k"))
promptContext.top_k = inputObject.Get("top_k").As<Napi::Number>().Int32Value();
if(inputObject.Has("top_p"))
promptContext.top_p = inputObject.Get("top_p").As<Napi::Number>().FloatValue();
if(inputObject.Has("min_p"))
promptContext.min_p = inputObject.Get("min_p").As<Napi::Number>().FloatValue();
if(inputObject.Has("temp"))
promptContext.temp = inputObject.Get("temp").As<Napi::Number>().FloatValue();
if(inputObject.Has("n_batch"))
promptContext.n_batch = inputObject.Get("n_batch").As<Napi::Number>().Int32Value();
if(inputObject.Has("repeat_penalty"))
promptContext.repeat_penalty = inputObject.Get("repeat_penalty").As<Napi::Number>().FloatValue();
if(inputObject.Has("repeat_last_n"))
promptContext.repeat_last_n = inputObject.Get("repeat_last_n").As<Napi::Number>().Int32Value();
if(inputObject.Has("context_erase"))
promptContext.context_erase = inputObject.Get("context_erase").As<Napi::Number>().FloatValue();
}
else
if (!info[1].IsObject())
{
Napi::Error::New(info.Env(), "Missing Prompt Options").ThrowAsJavaScriptException();
return info.Env().Undefined();
}
// defaults copied from python bindings
llmodel_prompt_context promptContext = {.logits = nullptr,
.tokens = nullptr,
.n_past = 0,
.n_ctx = nCtx,
.n_predict = 4096,
.top_k = 40,
.top_p = 0.9f,
.min_p = 0.0f,
.temp = 0.1f,
.n_batch = 8,
.repeat_penalty = 1.2f,
.repeat_last_n = 10,
.context_erase = 0.75};
PromptWorkerConfig promptWorkerConfig;
auto inputObject = info[1].As<Napi::Object>();
if(info.Length() >= 3 && info[2].IsFunction()){
promptWorkerConfig.bHasTokenCallback = true;
promptWorkerConfig.tokenCallback = info[2].As<Napi::Function>();
if (inputObject.Has("logits") || inputObject.Has("tokens"))
{
Napi::Error::New(info.Env(), "Invalid input: 'logits' or 'tokens' properties are not allowed")
.ThrowAsJavaScriptException();
return info.Env().Undefined();
}
// Assign the remaining properties
if (inputObject.Has("nPast") && inputObject.Get("nPast").IsNumber())
{
promptContext.n_past = inputObject.Get("nPast").As<Napi::Number>().Int32Value();
}
if (inputObject.Has("nPredict") && inputObject.Get("nPredict").IsNumber())
{
promptContext.n_predict = inputObject.Get("nPredict").As<Napi::Number>().Int32Value();
}
if (inputObject.Has("topK") && inputObject.Get("topK").IsNumber())
{
promptContext.top_k = inputObject.Get("topK").As<Napi::Number>().Int32Value();
}
if (inputObject.Has("topP") && inputObject.Get("topP").IsNumber())
{
promptContext.top_p = inputObject.Get("topP").As<Napi::Number>().FloatValue();
}
if (inputObject.Has("minP") && inputObject.Get("minP").IsNumber())
{
promptContext.min_p = inputObject.Get("minP").As<Napi::Number>().FloatValue();
}
if (inputObject.Has("temp") && inputObject.Get("temp").IsNumber())
{
promptContext.temp = inputObject.Get("temp").As<Napi::Number>().FloatValue();
}
if (inputObject.Has("nBatch") && inputObject.Get("nBatch").IsNumber())
{
promptContext.n_batch = inputObject.Get("nBatch").As<Napi::Number>().Int32Value();
}
if (inputObject.Has("repeatPenalty") && inputObject.Get("repeatPenalty").IsNumber())
{
promptContext.repeat_penalty = inputObject.Get("repeatPenalty").As<Napi::Number>().FloatValue();
}
if (inputObject.Has("repeatLastN") && inputObject.Get("repeatLastN").IsNumber())
{
promptContext.repeat_last_n = inputObject.Get("repeatLastN").As<Napi::Number>().Int32Value();
}
if (inputObject.Has("contextErase") && inputObject.Get("contextErase").IsNumber())
{
promptContext.context_erase = inputObject.Get("contextErase").As<Napi::Number>().FloatValue();
}
if (inputObject.Has("onPromptToken") && inputObject.Get("onPromptToken").IsFunction())
{
promptWorkerConfig.promptCallback = inputObject.Get("onPromptToken").As<Napi::Function>();
promptWorkerConfig.hasPromptCallback = true;
}
if (inputObject.Has("onResponseToken") && inputObject.Get("onResponseToken").IsFunction())
{
promptWorkerConfig.responseCallback = inputObject.Get("onResponseToken").As<Napi::Function>();
promptWorkerConfig.hasResponseCallback = true;
}
//copy to protect llmodel resources when splitting to new thread
// llmodel_prompt_context copiedPrompt = promptContext;
// copy to protect llmodel resources when splitting to new thread
// llmodel_prompt_context copiedPrompt = promptContext;
promptWorkerConfig.context = promptContext;
promptWorkerConfig.model = GetInference();
promptWorkerConfig.mutex = &inference_mutex;
promptWorkerConfig.prompt = question;
promptWorkerConfig.prompt = prompt;
promptWorkerConfig.result = "";
promptWorkerConfig.promptTemplate = inputObject.Get("promptTemplate").As<Napi::String>();
if (inputObject.Has("special"))
{
promptWorkerConfig.special = inputObject.Get("special").As<Napi::Boolean>();
}
if (inputObject.Has("fakeReply"))
{
// this will be deleted in the worker
promptWorkerConfig.fakeReply = new std::string(inputObject.Get("fakeReply").As<Napi::String>().Utf8Value());
}
auto worker = new PromptWorker(env, promptWorkerConfig);
worker->Queue();
return worker->GetPromise();
}
void NodeModelWrapper::Dispose(const Napi::CallbackInfo& info) {
}
void NodeModelWrapper::Dispose(const Napi::CallbackInfo &info)
{
llmodel_model_destroy(inference_);
}
void NodeModelWrapper::SetThreadCount(const Napi::CallbackInfo& info) {
if(info[0].IsNumber()) {
}
void NodeModelWrapper::SetThreadCount(const Napi::CallbackInfo &info)
{
if (info[0].IsNumber())
{
llmodel_setThreadCount(GetInference(), info[0].As<Napi::Number>().Int64Value());
} else {
Napi::Error::New(info.Env(), "Could not set thread count: argument 1 is NaN").ThrowAsJavaScriptException();
}
else
{
Napi::Error::New(info.Env(), "Could not set thread count: argument 1 is NaN").ThrowAsJavaScriptException();
return;
}
}
}
Napi::Value NodeModelWrapper::GetName(const Napi::CallbackInfo& info) {
Napi::Value NodeModelWrapper::GetName(const Napi::CallbackInfo &info)
{
return Napi::String::New(info.Env(), name);
}
Napi::Value NodeModelWrapper::ThreadCount(const Napi::CallbackInfo& info) {
}
Napi::Value NodeModelWrapper::ThreadCount(const Napi::CallbackInfo &info)
{
return Napi::Number::New(info.Env(), llmodel_threadCount(GetInference()));
}
}
Napi::Value NodeModelWrapper::GetLibraryPath(const Napi::CallbackInfo& info) {
return Napi::String::New(info.Env(),
llmodel_get_implementation_search_path());
}
Napi::Value NodeModelWrapper::GetLibraryPath(const Napi::CallbackInfo &info)
{
return Napi::String::New(info.Env(), llmodel_get_implementation_search_path());
}
llmodel_model NodeModelWrapper::GetInference() {
llmodel_model NodeModelWrapper::GetInference()
{
return inference_;
}
//Exports Bindings
Napi::Object Init(Napi::Env env, Napi::Object exports) {
exports["LLModel"] = NodeModelWrapper::GetClass(env);
return exports;
}
// Exports Bindings
Napi::Object Init(Napi::Env env, Napi::Object exports)
{
exports["LLModel"] = NodeModelWrapper::GetClass(env);
return exports;
}
NODE_API_MODULE(NODE_GYP_MODULE_NAME, Init)

@ -1,62 +1,63 @@
#include <napi.h>
#include "llmodel.h"
#include <iostream>
#include "llmodel_c.h"
#include "llmodel_c.h"
#include "prompt.h"
#include <atomic>
#include <memory>
#include <filesystem>
#include <set>
#include <iostream>
#include <memory>
#include <mutex>
#include <napi.h>
#include <set>
namespace fs = std::filesystem;
class NodeModelWrapper : public Napi::ObjectWrap<NodeModelWrapper>
{
public:
NodeModelWrapper(const Napi::CallbackInfo &);
// virtual ~NodeModelWrapper();
Napi::Value GetType(const Napi::CallbackInfo &info);
Napi::Value IsModelLoaded(const Napi::CallbackInfo &info);
Napi::Value StateSize(const Napi::CallbackInfo &info);
// void Finalize(Napi::Env env) override;
/**
* Prompting the model. This entails spawning a new thread and adding the response tokens
* into a thread local string variable.
*/
Napi::Value Infer(const Napi::CallbackInfo &info);
void SetThreadCount(const Napi::CallbackInfo &info);
void Dispose(const Napi::CallbackInfo &info);
Napi::Value GetName(const Napi::CallbackInfo &info);
Napi::Value ThreadCount(const Napi::CallbackInfo &info);
Napi::Value GenerateEmbedding(const Napi::CallbackInfo &info);
Napi::Value HasGpuDevice(const Napi::CallbackInfo &info);
Napi::Value ListGpus(const Napi::CallbackInfo &info);
Napi::Value InitGpuByString(const Napi::CallbackInfo &info);
Napi::Value GetRequiredMemory(const Napi::CallbackInfo &info);
Napi::Value GetGpuDevices(const Napi::CallbackInfo &info);
/*
* The path that is used to search for the dynamic libraries
*/
Napi::Value GetLibraryPath(const Napi::CallbackInfo &info);
/**
* Creates the LLModel class
*/
static Napi::Function GetClass(Napi::Env);
llmodel_model GetInference();
class NodeModelWrapper: public Napi::ObjectWrap<NodeModelWrapper> {
public:
NodeModelWrapper(const Napi::CallbackInfo &);
//virtual ~NodeModelWrapper();
Napi::Value GetType(const Napi::CallbackInfo& info);
Napi::Value IsModelLoaded(const Napi::CallbackInfo& info);
Napi::Value StateSize(const Napi::CallbackInfo& info);
//void Finalize(Napi::Env env) override;
/**
* Prompting the model. This entails spawning a new thread and adding the response tokens
* into a thread local string variable.
*/
Napi::Value Prompt(const Napi::CallbackInfo& info);
void SetThreadCount(const Napi::CallbackInfo& info);
void Dispose(const Napi::CallbackInfo& info);
Napi::Value GetName(const Napi::CallbackInfo& info);
Napi::Value ThreadCount(const Napi::CallbackInfo& info);
Napi::Value GenerateEmbedding(const Napi::CallbackInfo& info);
Napi::Value HasGpuDevice(const Napi::CallbackInfo& info);
Napi::Value ListGpus(const Napi::CallbackInfo& info);
Napi::Value InitGpuByString(const Napi::CallbackInfo& info);
Napi::Value GetRequiredMemory(const Napi::CallbackInfo& info);
Napi::Value GetGpuDevices(const Napi::CallbackInfo& info);
/*
* The path that is used to search for the dynamic libraries
*/
Napi::Value GetLibraryPath(const Napi::CallbackInfo& info);
/**
* Creates the LLModel class
*/
static Napi::Function GetClass(Napi::Env);
llmodel_model GetInference();
private:
/**
* The underlying inference that interfaces with the C interface
*/
llmodel_model inference_;
private:
/**
* The underlying inference that interfaces with the C interface
*/
llmodel_model inference_;
std::mutex inference_mutex;
std::mutex inference_mutex;
std::string type;
// corresponds to LLModel::name() in typescript
std::string name;
int nCtx{};
int nGpuLayers{};
std::string full_model_path;
std::string type;
// corresponds to LLModel::name() in typescript
std::string name;
int nCtx{};
int nGpuLayers{};
std::string full_model_path;
};

@ -1,6 +1,6 @@
{
"name": "gpt4all",
"version": "3.2.0",
"version": "4.0.0",
"packageManager": "yarn@3.6.1",
"main": "src/gpt4all.js",
"repository": "nomic-ai/gpt4all",
@ -22,7 +22,6 @@
],
"dependencies": {
"md5-file": "^5.0.0",
"mkdirp": "^3.0.1",
"node-addon-api": "^6.1.0",
"node-gyp-build": "^4.6.0"
},

@ -2,145 +2,195 @@
#include <future>
PromptWorker::PromptWorker(Napi::Env env, PromptWorkerConfig config)
: promise(Napi::Promise::Deferred::New(env)), _config(config), AsyncWorker(env) {
if(_config.bHasTokenCallback){
_tsfn = Napi::ThreadSafeFunction::New(config.tokenCallback.Env(),config.tokenCallback,"PromptWorker",0,1,this);
}
}
PromptWorker::~PromptWorker()
: promise(Napi::Promise::Deferred::New(env)), _config(config), AsyncWorker(env)
{
if (_config.hasResponseCallback)
{
if(_config.bHasTokenCallback){
_tsfn.Release();
}
_responseCallbackFn = Napi::ThreadSafeFunction::New(config.responseCallback.Env(), config.responseCallback,
"PromptWorker", 0, 1, this);
}
void PromptWorker::Execute()
if (_config.hasPromptCallback)
{
_config.mutex->lock();
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper *>(_config.model);
auto ctx = &_config.context;
if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size())
wrapper->promptContext.tokens.resize(ctx->n_past);
// Copy the C prompt context
wrapper->promptContext.n_past = ctx->n_past;
wrapper->promptContext.n_ctx = ctx->n_ctx;
wrapper->promptContext.n_predict = ctx->n_predict;
wrapper->promptContext.top_k = ctx->top_k;
wrapper->promptContext.top_p = ctx->top_p;
wrapper->promptContext.temp = ctx->temp;
wrapper->promptContext.n_batch = ctx->n_batch;
wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
wrapper->promptContext.contextErase = ctx->context_erase;
// Napi::Error::Fatal(
// "SUPRA",
// "About to prompt");
// Call the C++ prompt method
wrapper->llModel->prompt(
_config.prompt,
[](int32_t tid) { return true; },
[this](int32_t token_id, const std::string tok)
{
return ResponseCallback(token_id, tok);
},
[](bool isRecalculating)
{
return isRecalculating;
},
wrapper->promptContext);
        // Update the C context by giving access to the wrapper's raw pointers to std::vector data
// which involves no copies
ctx->logits = wrapper->promptContext.logits.data();
ctx->logits_size = wrapper->promptContext.logits.size();
ctx->tokens = wrapper->promptContext.tokens.data();
ctx->tokens_size = wrapper->promptContext.tokens.size();
// Update the rest of the C prompt context
ctx->n_past = wrapper->promptContext.n_past;
ctx->n_ctx = wrapper->promptContext.n_ctx;
ctx->n_predict = wrapper->promptContext.n_predict;
ctx->top_k = wrapper->promptContext.top_k;
ctx->top_p = wrapper->promptContext.top_p;
ctx->temp = wrapper->promptContext.temp;
ctx->n_batch = wrapper->promptContext.n_batch;
ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
ctx->context_erase = wrapper->promptContext.contextErase;
_config.mutex->unlock();
_promptCallbackFn = Napi::ThreadSafeFunction::New(config.promptCallback.Env(), config.promptCallback,
"PromptWorker", 0, 1, this);
}
}
void PromptWorker::OnOK()
PromptWorker::~PromptWorker()
{
if (_config.hasResponseCallback)
{
promise.Resolve(Napi::String::New(Env(), result));
_responseCallbackFn.Release();
}
void PromptWorker::OnError(const Napi::Error &e)
if (_config.hasPromptCallback)
{
promise.Reject(e.Value());
_promptCallbackFn.Release();
}
Napi::Promise PromptWorker::GetPromise()
}
void PromptWorker::Execute()
{
_config.mutex->lock();
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper *>(_config.model);
auto ctx = &_config.context;
if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size())
wrapper->promptContext.tokens.resize(ctx->n_past);
// Copy the C prompt context
wrapper->promptContext.n_past = ctx->n_past;
wrapper->promptContext.n_ctx = ctx->n_ctx;
wrapper->promptContext.n_predict = ctx->n_predict;
wrapper->promptContext.top_k = ctx->top_k;
wrapper->promptContext.top_p = ctx->top_p;
wrapper->promptContext.temp = ctx->temp;
wrapper->promptContext.n_batch = ctx->n_batch;
wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
wrapper->promptContext.contextErase = ctx->context_erase;
// Call the C++ prompt method
wrapper->llModel->prompt(
_config.prompt, _config.promptTemplate, [this](int32_t token_id) { return PromptCallback(token_id); },
[this](int32_t token_id, const std::string token) { return ResponseCallback(token_id, token); },
[](bool isRecalculating) { return isRecalculating; }, wrapper->promptContext, _config.special,
_config.fakeReply);
    // Update the C context by giving access to the wrapper's raw pointers to std::vector data
// which involves no copies
ctx->logits = wrapper->promptContext.logits.data();
ctx->logits_size = wrapper->promptContext.logits.size();
ctx->tokens = wrapper->promptContext.tokens.data();
ctx->tokens_size = wrapper->promptContext.tokens.size();
// Update the rest of the C prompt context
ctx->n_past = wrapper->promptContext.n_past;
ctx->n_ctx = wrapper->promptContext.n_ctx;
ctx->n_predict = wrapper->promptContext.n_predict;
ctx->top_k = wrapper->promptContext.top_k;
ctx->top_p = wrapper->promptContext.top_p;
ctx->temp = wrapper->promptContext.temp;
ctx->n_batch = wrapper->promptContext.n_batch;
ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
ctx->context_erase = wrapper->promptContext.contextErase;
_config.mutex->unlock();
}
void PromptWorker::OnOK()
{
Napi::Object returnValue = Napi::Object::New(Env());
returnValue.Set("text", result);
returnValue.Set("nPast", _config.context.n_past);
promise.Resolve(returnValue);
delete _config.fakeReply;
}
void PromptWorker::OnError(const Napi::Error &e)
{
delete _config.fakeReply;
promise.Reject(e.Value());
}
Napi::Promise PromptWorker::GetPromise()
{
return promise.Promise();
}
bool PromptWorker::ResponseCallback(int32_t token_id, const std::string token)
{
if (token_id == -1)
{
return promise.Promise();
return false;
}
bool PromptWorker::ResponseCallback(int32_t token_id, const std::string token)
if (!_config.hasResponseCallback)
{
if (token_id == -1)
{
return false;
}
if(!_config.bHasTokenCallback){
return true;
}
result += token;
std::promise<bool> promise;
auto info = new TokenCallbackInfo();
info->tokenId = token_id;
info->token = token;
info->total = result;
auto future = promise.get_future();
auto status = _tsfn.BlockingCall(info, [&promise](Napi::Env env, Napi::Function jsCallback, TokenCallbackInfo *value)
{
// Transform native data into JS data, passing it to the provided
// `jsCallback` -- the TSFN's JavaScript function.
auto token_id = Napi::Number::New(env, value->tokenId);
auto token = Napi::String::New(env, value->token);
auto total = Napi::String::New(env,value->total);
auto jsResult = jsCallback.Call({ token_id, token, total}).ToBoolean();
promise.set_value(jsResult);
// We're finished with the data.
delete value;
});
if (status != napi_ok) {
Napi::Error::Fatal(
"PromptWorkerResponseCallback",
"Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
}
return future.get();
return true;
}
bool PromptWorker::RecalculateCallback(bool isRecalculating)
result += token;
std::promise<bool> promise;
auto info = new ResponseCallbackData();
info->tokenId = token_id;
info->token = token;
auto future = promise.get_future();
auto status = _responseCallbackFn.BlockingCall(
info, [&promise](Napi::Env env, Napi::Function jsCallback, ResponseCallbackData *value) {
try
{
// Transform native data into JS data, passing it to the provided
// `jsCallback` -- the TSFN's JavaScript function.
auto token_id = Napi::Number::New(env, value->tokenId);
auto token = Napi::String::New(env, value->token);
auto jsResult = jsCallback.Call({token_id, token}).ToBoolean();
promise.set_value(jsResult);
}
catch (const Napi::Error &e)
{
std::cerr << "Error in onResponseToken callback: " << e.what() << std::endl;
promise.set_value(false);
}
delete value;
});
if (status != napi_ok)
{
return isRecalculating;
Napi::Error::Fatal("PromptWorkerResponseCallback", "Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
}
bool PromptWorker::PromptCallback(int32_t tid)
return future.get();
}
bool PromptWorker::RecalculateCallback(bool isRecalculating)
{
return isRecalculating;
}
bool PromptWorker::PromptCallback(int32_t token_id)
{
if (!_config.hasPromptCallback)
{
return true;
}
std::promise<bool> promise;
auto info = new PromptCallbackData();
info->tokenId = token_id;
auto future = promise.get_future();
auto status = _promptCallbackFn.BlockingCall(
info, [&promise](Napi::Env env, Napi::Function jsCallback, PromptCallbackData *value) {
try
{
// Transform native data into JS data, passing it to the provided
// `jsCallback` -- the TSFN's JavaScript function.
auto token_id = Napi::Number::New(env, value->tokenId);
auto jsResult = jsCallback.Call({token_id}).ToBoolean();
promise.set_value(jsResult);
}
catch (const Napi::Error &e)
{
std::cerr << "Error in onPromptToken callback: " << e.what() << std::endl;
promise.set_value(false);
}
delete value;
});
if (status != napi_ok)
{
Napi::Error::Fatal("PromptWorkerPromptCallback", "Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
}
return future.get();
}

@ -1,59 +1,72 @@
#ifndef PREDICT_WORKER_H
#define PREDICT_WORKER_H
#include "napi.h"
#include "llmodel_c.h"
#include "llmodel.h"
#include <thread>
#include <mutex>
#include <iostream>
#include "llmodel_c.h"
#include "napi.h"
#include <atomic>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
struct TokenCallbackInfo
{
int32_t tokenId;
std::string total;
std::string token;
};
struct ResponseCallbackData
{
int32_t tokenId;
std::string token;
};
struct LLModelWrapper
{
LLModel *llModel = nullptr;
LLModel::PromptContext promptContext;
~LLModelWrapper() { delete llModel; }
};
struct PromptCallbackData
{
int32_t tokenId;
};
struct PromptWorkerConfig
struct LLModelWrapper
{
LLModel *llModel = nullptr;
LLModel::PromptContext promptContext;
~LLModelWrapper()
{
Napi::Function tokenCallback;
bool bHasTokenCallback = false;
llmodel_model model;
std::mutex * mutex;
std::string prompt;
llmodel_prompt_context context;
std::string result;
};
class PromptWorker : public Napi::AsyncWorker
{
public:
PromptWorker(Napi::Env env, PromptWorkerConfig config);
~PromptWorker();
void Execute() override;
void OnOK() override;
void OnError(const Napi::Error &e) override;
Napi::Promise GetPromise();
bool ResponseCallback(int32_t token_id, const std::string token);
bool RecalculateCallback(bool isrecalculating);
bool PromptCallback(int32_t tid);
private:
Napi::Promise::Deferred promise;
std::string result;
PromptWorkerConfig _config;
Napi::ThreadSafeFunction _tsfn;
};
#endif // PREDICT_WORKER_H
delete llModel;
}
};
struct PromptWorkerConfig
{
Napi::Function responseCallback;
bool hasResponseCallback = false;
Napi::Function promptCallback;
bool hasPromptCallback = false;
llmodel_model model;
std::mutex *mutex;
std::string prompt;
std::string promptTemplate;
llmodel_prompt_context context;
std::string result;
bool special = false;
std::string *fakeReply = nullptr;
};
class PromptWorker : public Napi::AsyncWorker
{
public:
PromptWorker(Napi::Env env, PromptWorkerConfig config);
~PromptWorker();
void Execute() override;
void OnOK() override;
void OnError(const Napi::Error &e) override;
Napi::Promise GetPromise();
bool ResponseCallback(int32_t token_id, const std::string token);
bool RecalculateCallback(bool isrecalculating);
bool PromptCallback(int32_t token_id);
private:
Napi::Promise::Deferred promise;
std::string result;
PromptWorkerConfig _config;
Napi::ThreadSafeFunction _responseCallbackFn;
Napi::ThreadSafeFunction _promptCallbackFn;
};
#endif // PREDICT_WORKER_H

@ -24,7 +24,6 @@ mkdir -p "$NATIVE_DIR" "$BUILD_DIR"
cmake -S ../../gpt4all-backend -B "$BUILD_DIR" &&
cmake --build "$BUILD_DIR" -j --config Release && {
cp "$BUILD_DIR"/libbert*.$LIB_EXT "$NATIVE_DIR"/
cp "$BUILD_DIR"/libgptj*.$LIB_EXT "$NATIVE_DIR"/
cp "$BUILD_DIR"/libllama*.$LIB_EXT "$NATIVE_DIR"/
}

@ -0,0 +1,31 @@
import { promises as fs } from "node:fs";
import { loadModel, createCompletion } from "../src/gpt4all.js";
const model = await loadModel("Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf", {
verbose: true,
device: "gpu",
});
const res = await createCompletion(
model,
"I've got three 🍣 - What shall I name them?",
{
onPromptToken: (tokenId) => {
console.debug("onPromptToken", { tokenId });
// throwing an error will cancel
throw new Error("This is an error");
// const foo = thisMethodDoesNotExist();
// returning false will cancel as well
// return false;
},
onResponseToken: (tokenId, token) => {
console.debug("onResponseToken", { tokenId, token });
// same applies here
},
}
);
console.debug("Output:", {
usage: res.usage,
message: res.choices[0].message,
});

@ -0,0 +1,65 @@
import { loadModel, createCompletion } from "../src/gpt4all.js";
const model = await loadModel("Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf", {
verbose: true,
device: "gpu",
});
const chat = await model.createChatSession({
messages: [
{
role: "user",
content: "I'll tell you a secret password: It's 63445.",
},
{
role: "assistant",
content: "I will do my best to remember that.",
},
{
role: "user",
content:
"And here another fun fact: Bananas may be bluer than bread at night.",
},
{
role: "assistant",
content: "Yes, that makes sense.",
},
],
});
const turn1 = await createCompletion(
chat,
"Please tell me the secret password."
);
console.debug(turn1.choices[0].message);
// "The secret password you shared earlier is 63445.""
const turn2 = await createCompletion(
chat,
"Thanks! Have your heard about the bananas?"
);
console.debug(turn2.choices[0].message);
for (let i = 0; i < 32; i++) {
    // generate a number of extra turns to fill up the context window
const turn = await createCompletion(
chat,
i % 2 === 0 ? "Tell me a fun fact." : "And a boring one?"
);
console.debug({
message: turn.choices[0].message,
n_past_tokens: turn.usage.n_past_tokens,
});
}
const finalTurn = await createCompletion(
chat,
"Now I forgot the secret password. Can you remind me?"
);
console.debug(finalTurn.choices[0].message);
// result of finalTurn may vary depending on whether the generated facts pushed the secret out of the context window.
// "Of course! The secret password you shared earlier is 63445."
// "I apologize for any confusion. As an AI language model, ..."
model.dispose();

@ -0,0 +1,19 @@
import { loadModel, createCompletion } from "../src/gpt4all.js";
const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", {
verbose: true,
device: "gpu",
});
const chat = await model.createChatSession();
await createCompletion(
chat,
"Why are bananas rather blue than bread at night sometimes?",
{
verbose: true,
}
);
await createCompletion(chat, "Are you sure?", {
verbose: true,
});

@ -1,70 +0,0 @@
import { LLModel, createCompletion, DEFAULT_DIRECTORY, DEFAULT_LIBRARIES_DIRECTORY, loadModel } from '../src/gpt4all.js'
const model = await loadModel(
'mistral-7b-openorca.Q4_0.gguf',
{ verbose: true, device: 'gpu' }
);
const ll = model.llm;
try {
class Extended extends LLModel {
}
} catch(e) {
console.log("Extending from native class gone wrong " + e)
}
console.log("state size " + ll.stateSize())
console.log("thread count " + ll.threadCount());
ll.setThreadCount(5);
console.log("thread count " + ll.threadCount());
ll.setThreadCount(4);
console.log("thread count " + ll.threadCount());
console.log("name " + ll.name());
console.log("type: " + ll.type());
console.log("Default directory for models", DEFAULT_DIRECTORY);
console.log("Default directory for libraries", DEFAULT_LIBRARIES_DIRECTORY);
console.log("Has GPU", ll.hasGpuDevice());
console.log("gpu devices", ll.listGpu())
console.log("Required Mem in bytes", ll.memoryNeeded())
const completion1 = await createCompletion(model, [
{ role : 'system', content: 'You are an advanced mathematician.' },
{ role : 'user', content: 'What is 1 + 1?' },
], { verbose: true })
console.log(completion1.choices[0].message)
const completion2 = await createCompletion(model, [
{ role : 'system', content: 'You are an advanced mathematician.' },
{ role : 'user', content: 'What is two plus two?' },
], { verbose: true })
console.log(completion2.choices[0].message)
// Calling dispose will invalidate the native model. Use this to clean up.
model.dispose()
// At the moment, concurrent prompting of a single model instance is not possible:
// only the last prompt gets answered, the rest are cancelled.
// Threading through llama.cpp makes this hard (maybe impossible), so the attempt below is kept as a reference.
//const responses = await Promise.all([
// createCompletion(model, [
// { role : 'system', content: 'You are an advanced mathematician.' },
// { role : 'user', content: 'What is 1 + 1?' },
// ], { verbose: true }),
// createCompletion(model, [
// { role : 'system', content: 'You are an advanced mathematician.' },
// { role : 'user', content: 'What is 1 + 1?' },
// ], { verbose: true }),
//
//createCompletion(model, [
// { role : 'system', content: 'You are an advanced mathematician.' },
// { role : 'user', content: 'What is 1 + 1?' },
//], { verbose: true })
//
//])
//console.log(responses.map(s => s.choices[0].message))

@ -0,0 +1,29 @@
import {
loadModel,
createCompletion,
} from "../src/gpt4all.js";
const modelOptions = {
verbose: true,
};
const model1 = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", {
...modelOptions,
device: "gpu", // only one model can be on gpu
});
const model2 = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", modelOptions);
const model3 = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", modelOptions);
const promptContext = {
verbose: true,
}
const responses = await Promise.all([
createCompletion(model1, "What is 1 + 1?", promptContext),
// generating with the same model instance will wait for the previous completion to finish
createCompletion(model1, "What is 1 + 1?", promptContext),
// generating with different model instances will run in parallel
createCompletion(model2, "What is 1 + 2?", promptContext),
createCompletion(model3, "What is 1 + 3?", promptContext),
]);
console.log(responses.map((res) => res.choices[0].message));

@ -0,0 +1,26 @@
import { loadModel, createEmbedding } from '../src/gpt4all.js'
import { createGunzip, createGzip, createUnzip } from 'node:zlib';
import { Readable } from 'stream'
import readline from 'readline'
const embedder = await loadModel("nomic-embed-text-v1.5.f16.gguf", { verbose: true, type: 'embedding', device: 'gpu' })
console.log("Running with", embedder.llm.threadCount(), "threads");
const unzip = createGunzip();
const url = "https://huggingface.co/datasets/sentence-transformers/embedding-training-data/resolve/main/squad_pairs.jsonl.gz"
const stream = await fetch(url)
.then(res => Readable.fromWeb(res.body));
const lineReader = readline.createInterface({
input: stream.pipe(unzip),
crlfDelay: Infinity
})
lineReader.on('line', line => {
//pairs of questions and answers
const question_answer = JSON.parse(line)
console.log(createEmbedding(embedder, question_answer))
})
lineReader.on('close', () => embedder.dispose())

@ -1,6 +1,12 @@
import { loadModel, createEmbedding } from '../src/gpt4all.js'
const embedder = await loadModel("ggml-all-MiniLM-L6-v2-f16.bin", { verbose: true, type: 'embedding'})
const embedder = await loadModel("nomic-embed-text-v1.5.f16.gguf", { verbose: true, type: 'embedding' , device: 'gpu' })
console.log(createEmbedding(embedder, "Accept your current situation"))
try {
console.log(createEmbedding(embedder, ["Accept your current situation", "12312"], { prefix: "search_document" }))
} catch(e) {
console.log(e)
}
embedder.dispose()

@ -1,41 +0,0 @@
import gpt from '../src/gpt4all.js'
const model = await gpt.loadModel("mistral-7b-openorca.Q4_0.gguf", { device: 'gpu' })
process.stdout.write('Response: ')
const tokens = gpt.generateTokens(model, [{
role: 'user',
content: "How are you ?"
}], { nPredict: 2048 })
for await (const token of tokens){
process.stdout.write(token);
}
const result = await gpt.createCompletion(model, [{
role: 'user',
content: "You sure?"
}])
console.log(result)
const result2 = await gpt.createCompletion(model, [{
role: 'user',
content: "You sure you sure?"
}])
console.log(result2)
const tokens2 = gpt.generateTokens(model, [{
role: 'user',
content: "If 3 + 3 is 5, what is 2 + 2?"
}], { nPredict: 2048 })
for await (const token of tokens2){
process.stdout.write(token);
}
console.log("done")
model.dispose();

@ -0,0 +1,61 @@
import {
LLModel,
createCompletion,
DEFAULT_DIRECTORY,
DEFAULT_LIBRARIES_DIRECTORY,
loadModel,
} from "../src/gpt4all.js";
const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
verbose: true,
device: "gpu",
});
const ll = model.llm;
try {
class Extended extends LLModel {}
} catch (e) {
console.log("Extending from native class gone wrong " + e);
}
console.log("state size " + ll.stateSize());
console.log("thread count " + ll.threadCount());
ll.setThreadCount(5);
console.log("thread count " + ll.threadCount());
ll.setThreadCount(4);
console.log("thread count " + ll.threadCount());
console.log("name " + ll.name());
console.log("type: " + ll.type());
console.log("Default directory for models", DEFAULT_DIRECTORY);
console.log("Default directory for libraries", DEFAULT_LIBRARIES_DIRECTORY);
console.log("Has GPU", ll.hasGpuDevice());
console.log("gpu devices", ll.listGpu());
console.log("Required Mem in bytes", ll.memoryNeeded());
// to ingest a custom system prompt without using a chat session.
await createCompletion(
model,
"<|im_start|>system\nYou are an advanced mathematician.\n<|im_end|>\n",
{
promptTemplate: "%1",
nPredict: 0,
special: true,
}
);
const completion1 = await createCompletion(model, "What is 1 + 1?", {
verbose: true,
});
console.log(`🤖 > ${completion1.choices[0].message.content}`);
// Very specific caveat: when tested on Ubuntu 22.0 and Linux Mint, setting nPast to 100 made the app hang.
const completion2 = await createCompletion(model, "And if we add two?", {
verbose: true,
});
console.log(`🤖 > ${completion2.choices[0].message.content}`);
// Calling dispose will invalidate the native model. Use this to clean up.
model.dispose();
console.log("model disposed, exiting...");

@ -0,0 +1,21 @@
import { promises as fs } from "node:fs";
import { loadModel, createCompletion } from "../src/gpt4all.js";
const model = await loadModel("Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf", {
verbose: true,
device: "gpu",
nCtx: 32768,
});
const typeDefSource = await fs.readFile("./src/gpt4all.d.ts", "utf-8");
const res = await createCompletion(
model,
"Here are the type definitions for the GPT4All API:\n\n" +
typeDefSource +
"\n\nHow do I create a completion with a really large context window?",
{
verbose: true,
}
);
console.debug(res.choices[0].message);

@ -0,0 +1,60 @@
import { loadModel, createCompletion } from "../src/gpt4all.js";
const model1 = await loadModel("Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf", {
device: "gpu",
nCtx: 4096,
});
const chat1 = await model1.createChatSession({
temperature: 0.8,
topP: 0.7,
topK: 60,
});
const chat1turn1 = await createCompletion(
chat1,
"Outline a short story concept for adults. About why bananas are rather blue than bread is green at night sometimes. Not too long."
);
console.debug(chat1turn1.choices[0].message);
const chat1turn2 = await createCompletion(
chat1,
"Lets sprinkle some plot twists. And a cliffhanger at the end."
);
console.debug(chat1turn2.choices[0].message);
const chat1turn3 = await createCompletion(
chat1,
"Analyze your plot. Find the weak points."
);
console.debug(chat1turn3.choices[0].message);
const chat1turn4 = await createCompletion(
chat1,
"Rewrite it based on the analysis."
);
console.debug(chat1turn4.choices[0].message);
model1.dispose();
const model2 = await loadModel("gpt4all-falcon-newbpe-q4_0.gguf", {
device: "gpu",
});
const chat2 = await model2.createChatSession({
messages: chat1.messages,
});
const chat2turn1 = await createCompletion(
chat2,
"Give three ideas how this plot could be improved."
);
console.debug(chat2turn1.choices[0].message);
const chat2turn2 = await createCompletion(
chat2,
"Revise the plot, applying your ideas."
);
console.debug(chat2turn2.choices[0].message);
model2.dispose();

@ -0,0 +1,50 @@
import { loadModel, createCompletion } from "../src/gpt4all.js";
const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", {
verbose: true,
device: "gpu",
});
const messages = [
{
role: "system",
content: "<|im_start|>system\nYou are an advanced mathematician.\n<|im_end|>\n",
},
{
role: "user",
content: "What's 2+2?",
},
{
role: "assistant",
content: "5",
},
{
role: "user",
content: "Are you sure?",
},
];
const res1 = await createCompletion(model, messages);
console.debug(res1.choices[0].message);
messages.push(res1.choices[0].message);
messages.push({
role: "user",
content: "Could you double check that?",
});
const res2 = await createCompletion(model, messages);
console.debug(res2.choices[0].message);
messages.push(res2.choices[0].message);
messages.push({
role: "user",
content: "Let's bring out the big calculators.",
});
const res3 = await createCompletion(model, messages);
console.debug(res3.choices[0].message);
messages.push(res3.choices[0].message);
// console.debug(messages);

@ -0,0 +1,57 @@
import {
loadModel,
createCompletion,
createCompletionStream,
createCompletionGenerator,
} from "../src/gpt4all.js";
const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
device: "gpu",
});
process.stdout.write("### Stream:");
const stream = createCompletionStream(model, "How are you?");
stream.tokens.on("data", (data) => {
process.stdout.write(data);
});
await stream.result;
process.stdout.write("\n");
process.stdout.write("### Stream with pipe:");
const stream2 = createCompletionStream(
model,
"Please say something nice about node streams."
);
stream2.tokens.pipe(process.stdout);
const stream2Res = await stream2.result;
process.stdout.write("\n");
process.stdout.write("### Generator:");
const gen = createCompletionGenerator(model, "generators instead?", {
nPast: stream2Res.usage.n_past_tokens,
});
for await (const chunk of gen) {
process.stdout.write(chunk);
}
process.stdout.write("\n");
process.stdout.write("### Callback:");
await createCompletion(model, "Why not just callbacks?", {
onResponseToken: (tokenId, token) => {
process.stdout.write(token);
},
});
process.stdout.write("\n");
process.stdout.write("### 2nd Generator:");
const gen2 = createCompletionGenerator(model, "If 3 + 3 is 5, what is 2 + 2?");
let chunk = await gen2.next();
while (!chunk.done) {
process.stdout.write(chunk.value);
chunk = await gen2.next();
}
process.stdout.write("\n");
console.debug("generator finished", chunk);
model.dispose();

@ -0,0 +1,19 @@
import {
loadModel,
createCompletion,
} from "../src/gpt4all.js";
const model = await loadModel("Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf", {
verbose: true,
device: "gpu",
});
const chat = await model.createChatSession({
verbose: true,
systemPrompt: "<|im_start|>system\nRoleplay as Batman. Answer as if you are Batman, never say you're an Assistant.\n<|im_end|>",
});
const turn1 = await createCompletion(chat, "You have any plans tonight?");
console.log(turn1.choices[0].message);
// "I'm afraid I must decline any personal invitations tonight. As Batman, I have a responsibility to protect Gotham City."
model.dispose();

@ -0,0 +1,169 @@
const { DEFAULT_PROMPT_CONTEXT } = require("./config");
const { prepareMessagesForIngest } = require("./util");
class ChatSession {
model;
modelName;
/**
* @type {import('./gpt4all').ChatMessage[]}
*/
messages;
/**
* @type {string}
*/
systemPrompt;
/**
* @type {import('./gpt4all').LLModelPromptContext}
*/
promptContext;
/**
* @type {boolean}
*/
initialized;
constructor(model, chatSessionOpts = {}) {
const { messages, systemPrompt, ...sessionDefaultPromptContext } =
chatSessionOpts;
this.model = model;
this.modelName = model.llm.name();
this.messages = messages ?? [];
this.systemPrompt = systemPrompt ?? model.config.systemPrompt;
this.initialized = false;
this.promptContext = {
...DEFAULT_PROMPT_CONTEXT,
...sessionDefaultPromptContext,
nPast: 0,
};
}
async initialize(completionOpts = {}) {
if (this.model.activeChatSession !== this) {
this.model.activeChatSession = this;
}
let tokensIngested = 0;
// ingest system prompt
if (this.systemPrompt) {
const systemRes = await this.model.generate(this.systemPrompt, {
promptTemplate: "%1",
nPredict: 0,
special: true,
nBatch: this.promptContext.nBatch,
// verbose: true,
});
tokensIngested += systemRes.tokensIngested;
this.promptContext.nPast = systemRes.nPast;
}
// ingest initial messages
if (this.messages.length > 0) {
tokensIngested += await this.ingestMessages(
this.messages,
completionOpts
);
}
this.initialized = true;
return tokensIngested;
}
async ingestMessages(messages, completionOpts = {}) {
const turns = prepareMessagesForIngest(messages);
// send the message pairs to the model
let tokensIngested = 0;
for (const turn of turns) {
const turnRes = await this.model.generate(turn.user, {
...this.promptContext,
...completionOpts,
fakeReply: turn.assistant,
});
tokensIngested += turnRes.tokensIngested;
this.promptContext.nPast = turnRes.nPast;
}
return tokensIngested;
}
async generate(input, completionOpts = {}) {
if (this.model.activeChatSession !== this) {
throw new Error(
"Chat session is not active. Create a new chat session or call initialize to continue."
);
}
if (completionOpts.nPast > this.promptContext.nPast) {
throw new Error(
`nPast cannot be greater than ${this.promptContext.nPast}.`
);
}
let tokensIngested = 0;
if (!this.initialized) {
tokensIngested += await this.initialize(completionOpts);
}
let prompt = input;
if (Array.isArray(input)) {
// assuming input is a messages array
            // -> a tailing user message will be used as the final prompt; it is optional.
// -> all system messages will be ignored.
// -> all other messages will be ingested with fakeReply
// -> user/assistant messages will be pushed into the messages array
let tailingUserMessage = "";
let messagesToIngest = input;
const lastMessage = input[input.length - 1];
if (lastMessage.role === "user") {
tailingUserMessage = lastMessage.content;
messagesToIngest = input.slice(0, input.length - 1);
}
if (messagesToIngest.length > 0) {
tokensIngested += await this.ingestMessages(
messagesToIngest,
completionOpts
);
this.messages.push(...messagesToIngest);
}
if (tailingUserMessage) {
prompt = tailingUserMessage;
} else {
return {
text: "",
nPast: this.promptContext.nPast,
tokensIngested,
tokensGenerated: 0,
};
}
}
const result = await this.model.generate(prompt, {
...this.promptContext,
...completionOpts,
});
this.promptContext.nPast = result.nPast;
result.tokensIngested += tokensIngested;
this.messages.push({
role: "user",
content: prompt,
});
this.messages.push({
role: "assistant",
content: result.text,
});
return result;
}
}
module.exports = {
ChatSession,
};
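
To make the message-array handling in `generate` above concrete, here is a small hedged sketch (model name borrowed from the other examples; it assumes `createCompletion` forwards an array input to `ChatSession.generate` unchanged):

```js
import { loadModel, createCompletion } from "../src/gpt4all.js";

const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", { device: "gpu" });
const chat = await model.createChatSession();

// The leading user/assistant pair is ingested into the context via fakeReply,
// only the tailing user message is actually answered.
const res = await createCompletion(chat, [
    { role: "user", content: "My favorite color is blue." },
    { role: "assistant", content: "Noted." },
    { role: "user", content: "What is my favorite color?" },
]);
console.log(res.choices[0].message.content);

model.dispose();
```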

@ -27,15 +27,16 @@ const DEFAULT_MODEL_CONFIG = {
promptTemplate: "### Human:\n%1\n\n### Assistant:\n",
}
const DEFAULT_MODEL_LIST_URL = "https://gpt4all.io/models/models2.json";
const DEFAULT_MODEL_LIST_URL = "https://gpt4all.io/models/models3.json";
const DEFAULT_PROMPT_CONTEXT = {
temp: 0.7,
temp: 0.1,
topK: 40,
topP: 0.4,
topP: 0.9,
minP: 0.0,
repeatPenalty: 1.18,
repeatLastN: 64,
nBatch: 8,
repeatLastN: 10,
nBatch: 100,
}
module.exports = {

@ -1,43 +1,11 @@
/// <reference types="node" />
declare module "gpt4all";
type ModelType = "gptj" | "llama" | "mpt" | "replit";
// NOTE: "deprecated" tag in below comment breaks the doc generator https://github.com/documentationjs/documentation/issues/1596
/**
* Full list of models available
 * DEPRECATED! These model names are outdated and this type will not be maintained; please use a string literal instead.
*/
interface ModelFile {
/** List of GPT-J Models */
gptj:
| "ggml-gpt4all-j-v1.3-groovy.bin"
| "ggml-gpt4all-j-v1.2-jazzy.bin"
| "ggml-gpt4all-j-v1.1-breezy.bin"
| "ggml-gpt4all-j.bin";
/** List Llama Models */
llama:
| "ggml-gpt4all-l13b-snoozy.bin"
| "ggml-vicuna-7b-1.1-q4_2.bin"
| "ggml-vicuna-13b-1.1-q4_2.bin"
| "ggml-wizardLM-7B.q4_2.bin"
| "ggml-stable-vicuna-13B.q4_2.bin"
| "ggml-nous-gpt4-vicuna-13b.bin"
| "ggml-v3-13b-hermes-q5_1.bin";
/** List of MPT Models */
mpt:
| "ggml-mpt-7b-base.bin"
| "ggml-mpt-7b-chat.bin"
| "ggml-mpt-7b-instruct.bin";
/** List of Replit Models */
replit: "ggml-replit-code-v1-3b.bin";
}
interface LLModelOptions {
/**
     * Model architecture. This argument currently has no functionality and is only used as a descriptive identifier for the user.
*/
type?: ModelType;
type?: string;
model_name: string;
model_path: string;
library_path?: string;
@ -51,47 +19,259 @@ interface ModelConfig {
}
/**
* Callback for controlling token generation
* Options for the chat session.
*/
type TokenCallback = (tokenId: number, token: string, total: string) => boolean
interface ChatSessionOptions extends Partial<LLModelPromptContext> {
/**
* System prompt to ingest on initialization.
*/
systemPrompt?: string;
/**
* Messages to ingest on initialization.
*/
messages?: ChatMessage[];
}
/**
*
* InferenceModel represents an LLM which can make chat predictions, similar to GPT transformers.
*
* ChatSession utilizes an InferenceModel for efficient processing of chat conversations.
*/
declare class ChatSession implements CompletionProvider {
/**
* Constructs a new ChatSession using the provided InferenceModel and options.
* Does not set the chat session as the active chat session until initialize is called.
* @param {InferenceModel} model An InferenceModel instance.
* @param {ChatSessionOptions} [options] Options for the chat session including default completion options.
*/
constructor(model: InferenceModel, options?: ChatSessionOptions);
/**
* The underlying InferenceModel used for generating completions.
*/
model: InferenceModel;
/**
* The name of the model.
*/
modelName: string;
/**
* The messages that have been exchanged in this chat session.
*/
messages: ChatMessage[];
/**
* The system prompt that has been ingested at the beginning of the chat session.
*/
systemPrompt: string;
/**
* The current prompt context of the chat session.
*/
promptContext: LLModelPromptContext;
/**
* Ingests system prompt and initial messages.
* Sets this chat session as the active chat session of the model.
* @param {CompletionOptions} [options] Set completion options for initialization.
* @returns {Promise<number>} The number of tokens ingested during initialization. systemPrompt + messages.
*/
initialize(completionOpts?: CompletionOptions): Promise<number>;
/**
* Prompts the model in chat-session context.
* @param {CompletionInput} input Input string or message array.
* @param {CompletionOptions} [options] Set completion options for this generation.
* @returns {Promise<InferenceResult>} The inference result.
* @throws {Error} If the chat session is not the active chat session of the model.
* @throws {Error} If nPast is set to a value higher than what has been ingested in the session.
*/
generate(
input: CompletionInput,
options?: CompletionOptions
): Promise<InferenceResult>;
}
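
A hedged sketch of the single-active-session rule described in the `createChatSession` and `generate` docs (model name is illustrative; the exact error message may differ):

```js
import { loadModel, createCompletion } from "../src/gpt4all.js";

const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", { device: "gpu" });

const chatA = await model.createChatSession();
await createCompletion(chatA, "Hello from session A.");

// A model instance can only have one active chat session at a time,
// so creating and using a second session deactivates the first.
const chatB = await model.createChatSession();
await createCompletion(chatB, "Hello from session B.");

try {
    await createCompletion(chatA, "Are you still there?");
} catch (err) {
    console.error("Session A is no longer active:", err.message);
}

model.dispose();
```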
/**
* Shape of InferenceModel generations.
*/
interface InferenceResult extends LLModelInferenceResult {
tokensIngested: number;
tokensGenerated: number;
}
/**
* InferenceModel represents an LLM which can make next-token predictions.
*/
declare class InferenceModel {
declare class InferenceModel implements CompletionProvider {
constructor(llm: LLModel, config: ModelConfig);
/** The native LLModel */
llm: LLModel;
/** The configuration the instance was constructed with. */
config: ModelConfig;
/** The active chat session of the model. */
activeChatSession?: ChatSession;
/** The name of the model. */
modelName: string;
/**
* Create a chat session with the model and set it as the active chat session of this model.
* A model instance can only have one active chat session at a time.
* @param {ChatSessionOptions} options The options for the chat session.
* @returns {Promise<ChatSession>} The chat session.
*/
createChatSession(options?: ChatSessionOptions): Promise<ChatSession>;
/**
* Prompts the model with a given input and optional parameters.
* @param {CompletionInput} input The prompt input.
* @param {CompletionOptions} options Prompt context and other options.
* @returns {Promise<InferenceResult>} The model's response to the prompt.
* @throws {Error} If nPast is set to a value smaller than 0.
* @throws {Error} If a messages array without a tailing user message is provided.
*/
generate(
prompt: string,
options?: Partial<LLModelPromptContext>,
callback?: TokenCallback
): Promise<string>;
options?: CompletionOptions
): Promise<InferenceResult>;
/**
/**
* delete and cleanup the native model
*/
dispose(): void
*/
dispose(): void;
}
/**
* Options for generating one or more embeddings.
*/
interface EmbeddingOptions {
/**
* The model-specific prefix representing the embedding task, without the trailing colon. For Nomic Embed
* this can be `search_query`, `search_document`, `classification`, or `clustering`.
*/
prefix?: string;
/**
     * The embedding dimension, for use with Matryoshka-capable models. Defaults to full-size.
     * @default depends on the model being used.
*/
dimensionality?: number;
/**
* How to handle texts longer than the model can accept. One of `mean` or `truncate`.
* @default "mean"
*/
longTextMode?: "mean" | "truncate";
/**
* Try to be fully compatible with the Atlas API. Currently, this means texts longer than 8192 tokens
* with long_text_mode="mean" will raise an error. Disabled by default.
* @default false
*/
atlas?: boolean;
}
/**
 * The nodejs equivalent of the python binding's Embed4All().embed().
* @param {EmbeddingModel} model The embedding model instance.
* @param {string} text Text to embed.
* @param {EmbeddingOptions} options Optional parameters for the embedding.
* @returns {EmbeddingResult} The embedding result.
* @throws {Error} If dimensionality is set to a value smaller than 1.
*/
declare function createEmbedding(
model: EmbeddingModel,
text: string,
    options?: EmbeddingOptions
): EmbeddingResult<Float32Array>;
/**
* Overload that takes multiple strings to embed.
* @param {EmbeddingModel} model The embedding model instance.
* @param {string[]} texts Texts to embed.
* @param {EmbeddingOptions} options Optional parameters for the embedding.
* @returns {EmbeddingResult<Float32Array[]>} The embedding result.
* @throws {Error} If dimensionality is set to a value smaller than 1.
*/
declare function createEmbedding(
model: EmbeddingModel,
text: string[],
    options?: EmbeddingOptions
): EmbeddingResult<Float32Array[]>;
/**
* The resulting embedding.
*/
interface EmbeddingResult<T> {
/**
* Encoded token count. Includes overlap but specifically excludes tokens used for the prefix/task_type, BOS/CLS token, and EOS/SEP token
**/
n_prompt_tokens: number;
embeddings: T;
}
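
A short sketch of the embedding options and result shape above, assuming the nomic-embed-text model used elsewhere in this changeset; the `search_document` prefix is documented above, while the 512-dimension truncation is only an illustrative Matryoshka value:

```js
import { loadModel, createEmbedding } from "../src/gpt4all.js";

const embedder = await loadModel("nomic-embed-text-v1.5.f16.gguf", { type: "embedding" });

// Single text: embeddings is a Float32Array.
const single = createEmbedding(embedder, "The quick brown fox", {
    prefix: "search_document",
    dimensionality: 512, // Matryoshka truncation, omit for the full size
});
console.log(single.n_prompt_tokens, single.embeddings.length);

// Array of texts: embeddings is an array of Float32Arrays.
const batch = createEmbedding(embedder, ["first text", "second text"], {
    prefix: "search_document",
});
console.log(batch.embeddings.map((e) => e.length));

embedder.dispose();
```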
/**
* EmbeddingModel represents an LLM which can create embeddings, which are float arrays
*/
declare class EmbeddingModel {
constructor(llm: LLModel, config: ModelConfig);
/** The native LLModel */
llm: LLModel;
/** The configuration the instance was constructed with. */
config: ModelConfig;
embed(text: string): Float32Array;
/**
* Create an embedding from a given input string. See EmbeddingOptions.
* @param {string} text
* @param {string} prefix
* @param {number} dimensionality
* @param {boolean} doMean
* @param {boolean} atlas
* @returns {EmbeddingResult<Float32Array>} The embedding result.
*/
embed(
text: string,
prefix: string,
dimensionality: number,
doMean: boolean,
atlas: boolean
): EmbeddingResult<Float32Array>;
/**
* Create an embedding from a given input text array. See EmbeddingOptions.
* @param {string[]} text
* @param {string} prefix
* @param {number} dimensionality
* @param {boolean} doMean
* @param {boolean} atlas
* @returns {EmbeddingResult<Float32Array[]>} The embedding result.
*/
embed(
text: string[],
prefix: string,
dimensionality: number,
doMean: boolean,
atlas: boolean
): EmbeddingResult<Float32Array[]>;
/**
* delete and cleanup the native model
* delete and cleanup the native model
*/
dispose(): void
dispose(): void;
}
/**
* Shape of LLModel's inference result.
*/
interface LLModelInferenceResult {
text: string;
nPast: number;
}
interface LLModelInferenceOptions extends Partial<LLModelPromptContext> {
/** Callback for response tokens, called for each generated token.
* @param {number} tokenId The token id.
* @param {string} token The token.
* @returns {boolean | undefined} Whether to continue generating tokens.
* */
onResponseToken?: (tokenId: number, token: string) => boolean | void;
/** Callback for prompt tokens, called for each input token in the prompt.
* @param {number} tokenId The token id.
* @returns {boolean | undefined} Whether to continue ingesting the prompt.
* */
onPromptToken?: (tokenId: number) => boolean | void;
}
/**
@ -101,14 +281,13 @@ declare class EmbeddingModel {
declare class LLModel {
/**
* Initialize a new LLModel.
* @param path Absolute path to the model file.
* @param {string} path Absolute path to the model file.
* @throws {Error} If the model file does not exist.
*/
constructor(path: string);
constructor(options: LLModelOptions);
/** either 'gpt', mpt', or 'llama' or undefined */
type(): ModelType | undefined;
/** undefined or user supplied */
type(): string | undefined;
/** The name of the model. */
name(): string;
@ -134,29 +313,53 @@ declare class LLModel {
setThreadCount(newNumber: number): void;
/**
* Prompt the model with a given input and optional parameters.
* This is the raw output from model.
* Use the prompt function exported for a value
* @param q The prompt input.
* @param params Optional parameters for the prompt context.
* @param callback - optional callback to control token generation.
* @returns The result of the model prompt.
* Prompt the model directly with a given input string and optional parameters.
* Use the higher level createCompletion methods for a more user-friendly interface.
* @param {string} prompt The prompt input.
* @param {LLModelInferenceOptions} options Optional parameters for the generation.
* @returns {LLModelInferenceResult} The response text and final context size.
*/
infer(
prompt: string,
options: LLModelInferenceOptions
): Promise<LLModelInferenceResult>;
/**
* Embed text with the model. See EmbeddingOptions for more information.
* Use the higher level createEmbedding methods for a more user-friendly interface.
* @param {string} text
* @param {string} prefix
* @param {number} dimensionality
* @param {boolean} doMean
* @param {boolean} atlas
* @returns {Float32Array} The embedding of the text.
*/
raw_prompt(
q: string,
params: Partial<LLModelPromptContext>,
callback?: TokenCallback
): Promise<string>
embed(
text: string,
prefix: string,
dimensionality: number,
doMean: boolean,
atlas: boolean
): Float32Array;
/**
* Embed text with the model. Keep in mind that
* not all models can embed text, (only bert can embed as of 07/16/2023 (mm/dd/yyyy))
* Use the prompt function exported for a value
* @param q The prompt input.
* @param params Optional parameters for the prompt context.
* @returns The result of the model prompt.
* Embed multiple texts with the model. See EmbeddingOptions for more information.
* Use the higher level createEmbedding methods for a more user-friendly interface.
* @param {string[]} texts
* @param {string} prefix
* @param {number} dimensionality
* @param {boolean} doMean
* @param {boolean} atlas
* @returns {Float32Array[]} The embeddings of the texts.
*/
embed(text: string): Float32Array;
embed(
texts: string,
prefix: string,
dimensionality: number,
doMean: boolean,
atlas: boolean
): Float32Array[];
/**
* Whether the model is loaded or not.
*/
@ -166,81 +369,97 @@ declare class LLModel {
* Where to search for the pluggable backend libraries
*/
setLibraryPath(s: string): void;
/**
* Where to get the pluggable backend libraries
*/
getLibraryPath(): string;
/**
* Initiate a GPU by a string identifier.
* @param {number} memory_required Should be in the range size_t or will throw
* Initiate a GPU by a string identifier.
* @param {number} memory_required Should be in the range size_t or will throw
* @param {string} device_name 'amd' | 'nvidia' | 'intel' | 'gpu' | gpu name.
* read LoadModelOptions.device for more information
*/
initGpuByString(memory_required: number, device_name: string): boolean
initGpuByString(memory_required: number, device_name: string): boolean;
/**
* From C documentation
* @returns True if a GPU device is successfully initialized, false otherwise.
*/
hasGpuDevice(): boolean
hasGpuDevice(): boolean;
/**
* GPUs that are usable for this LLModel
* @param nCtx Maximum size of context window
* @throws if hasGpuDevice returns false (i think)
* @returns
*/
listGpu(nCtx: number) : GpuDevice[]
* GPUs that are usable for this LLModel
* @param {number} nCtx Maximum size of context window
* @throws if hasGpuDevice returns false
* @returns {GpuDevice[]} The available GPU devices.
*/
listGpu(nCtx: number): GpuDevice[];
/**
* delete and cleanup the native model
* delete and cleanup the native model
*/
dispose(): void
dispose(): void;
}
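// Illustrative sketch only: driving the low-level LLModel API directly through the
// `llm` handle on an InferenceModel. Option names mirror the declarations above;
// prefer createCompletion for normal use.
import { loadModel } from "../src/gpt4all.js";

const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf");
const res = await model.llm.infer("What is 1 + 1?", {
    promptTemplate: model.config.promptTemplate,
    nPredict: 64,
    onResponseToken: (tokenId, token) => {
        process.stdout.write(token);
        return true; // returning false stops generation early
    },
});
console.log("\nfinal context size (nPast):", res.nPast);
model.dispose();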
/**
/**
* an object that contains gpu data on this machine.
*/
interface GpuDevice {
index: number;
/**
* same as VkPhysicalDeviceType
* same as VkPhysicalDeviceType
*/
type: number;
heapSize : number;
type: number;
heapSize: number;
name: string;
vendor: string;
}
/**
* Options that configure a model's behavior.
*/
* Options that configure a model's behavior.
*/
interface LoadModelOptions {
/**
* Where to look for model files.
*/
modelPath?: string;
/**
* Where to look for the backend libraries.
*/
librariesPath?: string;
/**
* The path to the model configuration file, useful for offline usage or custom model configurations.
*/
modelConfigFile?: string;
/**
* Whether to allow downloading the model if it is not present at the specified path.
*/
allowDownload?: boolean;
/**
* Enable verbose logging.
*/
verbose?: boolean;
/* The processing unit on which the model will run. It can be set to
/**
* The processing unit on which the model will run. It can be set to
* - "cpu": Model will run on the central processing unit.
* - "gpu": Model will run on the best available graphics processing unit, irrespective of its vendor.
* - "amd", "nvidia", "intel": Model will run on the best available GPU from the specified vendor.
Alternatively, a specific GPU name can also be provided, and the model will run on the GPU that matches the name
if it's available.
Default is "cpu".
Note: If a GPU device lacks sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All
instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the
model.
*/
* - "gpu name": Model will run on the GPU that matches the name if it's available.
* Note: If a GPU device lacks sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All
* instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the
* model.
* @default "cpu"
*/
device?: string;
/*
/**
* The Maximum window size of this model
* Default of 2048
* @default 2048
*/
nCtx?: number;
/*
/**
* Number of gpu layers needed
* Default of 100
* @default 100
*/
ngl?: number;
}
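// A minimal loadModel call exercising these options (model name is a placeholder;
// the defaults shown match the loadModel implementation further down).
import { loadModel, DEFAULT_DIRECTORY, DEFAULT_LIBRARIES_DIRECTORY } from "../src/gpt4all.js";

const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
    modelPath: DEFAULT_DIRECTORY,
    librariesPath: DEFAULT_LIBRARIES_DIRECTORY,
    allowDownload: true,
    verbose: false,
    device: "gpu", // "cpu", "amd", "nvidia", "intel", or a specific GPU name
    nCtx: 2048,
    ngl: 100,
});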
@ -277,66 +496,84 @@ declare function loadModel(
): Promise<InferenceModel | EmbeddingModel>;
/**
* The nodejs equivalent to python binding's chat_completion
* @param {InferenceModel} model - The language model object.
* @param {PromptMessage[]} messages - The array of messages for the conversation.
* @param {CompletionOptions} options - The options for creating the completion.
* @returns {CompletionReturn} The completion result.
*/
declare function createCompletion(
model: InferenceModel,
messages: PromptMessage[],
options?: CompletionOptions
): Promise<CompletionReturn>;
/**
* The nodejs moral equivalent to python binding's Embed4All().embed()
* meow
* @param {EmbeddingModel} model - The language model object.
* @param {string} text - text to embed
* @returns {Float32Array} The completion result.
* Interface for createCompletion methods, implemented by InferenceModel and ChatSession.
* Implement your own CompletionProvider or extend ChatSession to generate completions with custom logic.
*/
declare function createEmbedding(
model: EmbeddingModel,
text: string
): Float32Array;
interface CompletionProvider {
modelName: string;
generate(
input: CompletionInput,
options?: CompletionOptions
): Promise<InferenceResult>;
}
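// Sketch of a custom CompletionProvider. The InferenceResult field names
// (text, nPast, tokensIngested, tokensGenerated) are inferred from the
// InferenceModel implementation further down; the echo logic is purely illustrative.
const echoProvider = {
    modelName: "echo-provider",
    async generate(input, options = {}) {
        const text =
            typeof input === "string" ? input : input[input.length - 1].content;
        return { text, nPast: 0, tokensIngested: 0, tokensGenerated: 0 };
    },
};
// const completion = await createCompletion(echoProvider, "hello");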
/**
* The options for creating the completion.
* Options for creating a completion.
*/
interface CompletionOptions extends Partial<LLModelPromptContext> {
interface CompletionOptions extends LLModelInferenceOptions {
/**
* Indicates if verbose logging is enabled.
* @default true
* @default false
*/
verbose?: boolean;
}
/**
* Template for the system message. Will be put before the conversation with %1 being replaced by all system messages.
* Note that if this is not defined, system messages will not be included in the prompt.
*/
systemPromptTemplate?: string;
/**
* The input for creating a completion. May be a string or an array of messages.
*/
type CompletionInput = string | ChatMessage[];
/**
* Template for user messages, with %1 being replaced by the message.
*/
promptTemplate?: boolean;
/**
* The nodejs equivalent to python binding's chat_completion
* @param {CompletionProvider} provider - The inference model object or chat session
* @param {CompletionInput} input - The input string or message array
* @param {CompletionOptions} options - The options for creating the completion.
* @returns {CompletionResult} The completion result.
*/
declare function createCompletion(
provider: CompletionProvider,
input: CompletionInput,
options?: CompletionOptions
): Promise<CompletionResult>;
/**
* The initial instruction for the model, on top of the prompt
*/
promptHeader?: string;
/**
* Streaming variant of createCompletion, returns a stream of tokens and a promise that resolves to the completion result.
* @param {CompletionProvider} provider - The inference model object or chat session
* @param {CompletionInput} input - The input string or message array
* @param {CompletionOptions} options - The options for creating the completion.
* @returns {CompletionStreamReturn} An object of token stream and the completion result promise.
*/
declare function createCompletionStream(
provider: CompletionProvider,
input: CompletionInput,
options?: CompletionOptions
): CompletionStreamReturn;
/**
* The last instruction for the model, appended to the end of the prompt.
*/
promptFooter?: string;
/**
* The result of a streamed completion, containing a stream of tokens and a promise that resolves to the completion result.
*/
interface CompletionStreamReturn {
tokens: NodeJS.ReadableStream;
result: Promise<CompletionResult>;
}
/**
* A message in the conversation, identical to OpenAI's chat message.
* Async generator variant of createCompletion, yields tokens as they are generated and returns the completion result.
* @param {CompletionProvider} provider - The inference model object or chat session
* @param {CompletionInput} input - The input string or message array
* @param {CompletionOptions} options - The options for creating the completion.
* @returns {AsyncGenerator<string>} The stream of generated tokens
*/
declare function createCompletionGenerator(
provider: CompletionProvider,
input: CompletionInput,
options: CompletionOptions
): AsyncGenerator<string, CompletionResult>;
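// Possible usage of the async generator variant (model name is a placeholder).
// The final CompletionResult is the generator's return value and is not yielded.
import { loadModel, createCompletionGenerator } from "../src/gpt4all.js";

const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf");
for await (const token of createCompletionGenerator(model, "Tell me a short joke.", {})) {
    process.stdout.write(token);
}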
/**
* A message in the conversation.
*/
interface PromptMessage {
interface ChatMessage {
/** The role of the message. */
role: "system" | "assistant" | "user";
@ -345,34 +582,31 @@ interface PromptMessage {
}
/**
* The result of the completion, similar to OpenAI's format.
* The result of a completion.
*/
interface CompletionReturn {
interface CompletionResult {
/** The model used for the completion. */
model: string;
/** Token usage report. */
usage: {
/** The number of tokens used in the prompt. */
/** The number of tokens ingested during the completion. */
prompt_tokens: number;
/** The number of tokens used in the completion. */
/** The number of tokens generated in the completion. */
completion_tokens: number;
/** The total number of tokens used. */
total_tokens: number;
};
/** The generated completions. */
choices: CompletionChoice[];
}
/** Number of tokens used in the conversation. */
n_past_tokens: number;
};
/**
* A completion choice, similar to OpenAI's format.
*/
interface CompletionChoice {
/** Response message */
message: PromptMessage;
/** The generated completion. */
choices: Array<{
message: ChatMessage;
}>;
}
/**
@ -385,19 +619,33 @@ interface LLModelPromptContext {
/** The size of the raw tokens vector. */
tokensSize: number;
/** The number of tokens in the past conversation. */
/** The number of tokens in the past conversation.
* This may be used to "roll back" the conversation to a previous state.
* Note that for most use cases the default value should be sufficient and this should not be set.
* @default 0 For completions using InferenceModel, meaning the model will only consider the input prompt.
* @default nPast For completions using ChatSession. This means the context window will be automatically determined
* and possibly resized (see contextErase) to keep the conversation performant.
* */
nPast: number;
/** The number of tokens possible in the context window.
* @default 1024
*/
nCtx: number;
/** The number of tokens to predict.
* @default 128
/** The maximum number of tokens to predict.
* @default 4096
* */
nPredict: number;
/** Template for user / assistant message pairs.
* %1 is required and will be replaced by the user input.
* %2 is optional and will be replaced by the assistant response. If not present, the assistant response will be appended.
*/
promptTemplate?: string;
/** The context window size. Deprecated: it has no effect here.
     * Use loadModel's nCtx option instead.
* @default 2048
*/
nCtx: number;
/** The top-k logits to sample from.
* Top-K sampling selects the next token only from the top K most likely tokens predicted by the model.
* It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit
@ -409,26 +657,33 @@ interface LLModelPromptContext {
topK: number;
/** The nucleus sampling probability threshold.
* Top-P limits the selection of the next token to a subset of tokens with a cumulative probability
* Top-P limits the selection of the next token to a subset of tokens with a cumulative probability
* above a threshold P. This method, also known as nucleus sampling, finds a balance between diversity
* and quality by considering both token probabilities and the number of tokens available for sampling.
* When using a higher value for top-P (e.g., 0.95), the generated text becomes more diverse.
* On the other hand, a lower value (e.g., 0.1) produces more focused and conservative text.
* The default value is 0.4, which is aimed to be the middle ground between focus and diversity, but
* for more creative tasks a higher top-p value will be beneficial, about 0.5-0.9 is a good range for that.
* @default 0.4
* @default 0.9
*
* */
topP: number;
/**
* The minimum probability of a token to be considered.
* @default 0.0
*/
minP: number;
/** The temperature to adjust the model's output distribution.
* Temperature is like a knob that adjusts how creative or focused the output becomes. Higher temperatures
* (e.g., 1.2) increase randomness, resulting in more imaginative and diverse text. Lower temperatures (e.g., 0.5)
* make the output more focused, predictable, and conservative. When the temperature is set to 0, the output
* becomes completely deterministic, always selecting the most probable next token and producing identical results
* each time. A safe range would be around 0.6 - 0.85, but you are free to search what value fits best for you.
* @default 0.7
* each time. Experiment to find the value that best fits your use case and model.
* @default 0.1
* @alias temperature
* */
temp: number;
temperature: number;
/** The number of predictions to generate in parallel.
* By splitting the prompt every N tokens, prompt-batch-size reduces RAM usage during processing. However,
@ -451,31 +706,17 @@ interface LLModelPromptContext {
* The repeat-penalty-tokens N option controls the number of tokens in the history to consider for penalizing repetition.
* A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only
* consider recent tokens.
* @default 64
* @default 10
* */
repeatLastN: number;
/** The percentage of context to erase if the context window is exceeded.
* @default 0.5
* Set it to a lower value to keep context for longer at the cost of performance.
* @default 0.75
* */
contextErase: number;
}
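// Sketch: the LLModelPromptContext fields above can be passed through
// CompletionOptions. The values below are arbitrary examples, not tuned
// recommendations.
import { loadModel, createCompletion } from "../src/gpt4all.js";

const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf");
const answer = await createCompletion(model, "Write a haiku about GPUs.", {
    nPredict: 128,
    temperature: 0.7,
    topK: 40,
    topP: 0.9,
    minP: 0.0,
    repeatLastN: 10,
    contextErase: 0.75,
});
console.log(answer.choices[0].message.content);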
/**
* Creates an async generator of tokens
* @param {InferenceModel} llmodel - The language model object.
* @param {PromptMessage[]} messages - The array of messages for the conversation.
* @param {CompletionOptions} options - The options for creating the completion.
* @param {TokenCallback} callback - optional callback to control token generation.
* @returns {AsyncGenerator<string>} The stream of generated tokens
*/
declare function generateTokens(
llmodel: InferenceModel,
messages: PromptMessage[],
options: CompletionOptions,
callback?: TokenCallback
): AsyncGenerator<string>;
/**
* From python api:
* models will be stored in (homedir)/.cache/gpt4all/`
@ -508,7 +749,7 @@ declare const DEFAULT_MODEL_LIST_URL: string;
* Initiates the download of a model file.
* By default this downloads without waiting. Use the returned controller to alter this behavior.
* @param {string} modelName - The model to be downloaded.
* @param {DownloadOptions} options - to pass into the downloader. Default is { location: (cwd), verbose: false }.
* @param {DownloadModelOptions} options - to pass into the downloader. Default is { location: (cwd), verbose: false }.
* @returns {DownloadController} object that allows controlling the download process.
*
* @throws {Error} If the model already exists in the specified location.
@ -556,7 +797,9 @@ interface ListModelsOptions {
file?: string;
}
declare function listModels(options?: ListModelsOptions): Promise<ModelConfig[]>;
declare function listModels(
options?: ListModelsOptions
): Promise<ModelConfig[]>;
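// Possible listModels usage; the remote list URL below matches the models3.json
// default used by retrieveModel further down.
import { listModels } from "../src/gpt4all.js";

const models = await listModels({
    url: "https://gpt4all.io/models/models3.json",
});
console.log(models[0]); // a ModelConfig entry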
interface RetrieveModelOptions {
allowDownload?: boolean;
@ -581,30 +824,35 @@ interface DownloadController {
}
export {
ModelType,
ModelFile,
LLModel,
LLModelPromptContext,
ModelConfig,
InferenceModel,
InferenceResult,
EmbeddingModel,
LLModel,
LLModelPromptContext,
PromptMessage,
EmbeddingResult,
ChatSession,
ChatMessage,
CompletionInput,
CompletionProvider,
CompletionOptions,
CompletionResult,
LoadModelOptions,
DownloadController,
RetrieveModelOptions,
DownloadModelOptions,
GpuDevice,
loadModel,
downloadModel,
retrieveModel,
listModels,
createCompletion,
createCompletionStream,
createCompletionGenerator,
createEmbedding,
generateTokens,
DEFAULT_DIRECTORY,
DEFAULT_LIBRARIES_DIRECTORY,
DEFAULT_MODEL_CONFIG,
DEFAULT_PROMPT_CONTEXT,
DEFAULT_MODEL_LIST_URL,
downloadModel,
retrieveModel,
listModels,
DownloadController,
RetrieveModelOptions,
DownloadModelOptions,
GpuDevice
};

@ -2,8 +2,10 @@
/// This file implements the gpt4all.d.ts file endings.
/// Written in commonjs to support both ESM and CJS projects.
const { existsSync } = require("fs");
const { existsSync } = require("node:fs");
const path = require("node:path");
const Stream = require("node:stream");
const assert = require("node:assert");
const { LLModel } = require("node-gyp-build")(path.resolve(__dirname, ".."));
const {
retrieveModel,
@ -18,15 +20,14 @@ const {
DEFAULT_MODEL_LIST_URL,
} = require("./config.js");
const { InferenceModel, EmbeddingModel } = require("./models.js");
const Stream = require('stream')
const assert = require("assert");
const { ChatSession } = require("./chat-session.js");
/**
* Loads a machine learning model with the specified name. This is the de facto way to create a model.
* By default this will download the model from the official GPT4All website if it is not present at the given path.
*
* @param {string} modelName - The name of the model to load.
* @param {LoadModelOptions|undefined} [options] - (Optional) Additional options for loading the model.
* @param {import('./gpt4all').LoadModelOptions|undefined} [options] - (Optional) Additional options for loading the model.
* @returns {Promise<InferenceModel | EmbeddingModel>} A promise that resolves to an instance of the loaded LLModel.
*/
async function loadModel(modelName, options = {}) {
@ -35,10 +36,10 @@ async function loadModel(modelName, options = {}) {
librariesPath: DEFAULT_LIBRARIES_DIRECTORY,
type: "inference",
allowDownload: true,
verbose: true,
device: 'cpu',
verbose: false,
device: "cpu",
nCtx: 2048,
ngl : 100,
ngl: 100,
...options,
};
@ -49,12 +50,14 @@ async function loadModel(modelName, options = {}) {
verbose: loadOptions.verbose,
});
assert.ok(typeof loadOptions.librariesPath === 'string');
assert.ok(
typeof loadOptions.librariesPath === "string",
"Libraries path should be a string"
);
const existingPaths = loadOptions.librariesPath
.split(";")
.filter(existsSync)
.join(';');
console.log("Passing these paths into runtime library search:", existingPaths)
.join(";");
const llmOptions = {
model_name: appendBinSuffixIfMissing(modelName),
@ -62,13 +65,15 @@ async function loadModel(modelName, options = {}) {
library_path: existingPaths,
device: loadOptions.device,
nCtx: loadOptions.nCtx,
ngl: loadOptions.ngl
ngl: loadOptions.ngl,
};
if (loadOptions.verbose) {
console.debug("Creating LLModel with options:", llmOptions);
console.debug("Creating LLModel:", {
llmOptions,
modelConfig,
});
}
console.log(modelConfig)
const llmodel = new LLModel(llmOptions);
if (loadOptions.type === "embedding") {
return new EmbeddingModel(llmodel, modelConfig);
@ -79,75 +84,43 @@ async function loadModel(modelName, options = {}) {
}
}
/**
* Formats a list of messages into a single prompt string.
*/
function formatChatPrompt(
messages,
{
systemPromptTemplate,
defaultSystemPrompt,
promptTemplate,
promptFooter,
promptHeader,
}
) {
const systemMessages = messages
.filter((message) => message.role === "system")
.map((message) => message.content);
let fullPrompt = "";
if (promptHeader) {
fullPrompt += promptHeader + "\n\n";
}
if (systemPromptTemplate) {
// if user specified a template for the system prompt, put all system messages in the template
let systemPrompt = "";
if (systemMessages.length > 0) {
systemPrompt += systemMessages.join("\n");
}
function createEmbedding(model, text, options={}) {
let {
dimensionality = undefined,
longTextMode = "mean",
atlas = false,
} = options;
if (systemPrompt) {
fullPrompt +=
systemPromptTemplate.replace("%1", systemPrompt) + "\n";
}
} else if (defaultSystemPrompt) {
// otherwise, use the system prompt from the model config and ignore system messages
fullPrompt += defaultSystemPrompt + "\n\n";
}
if (systemMessages.length > 0 && !systemPromptTemplate) {
console.warn(
"System messages were provided, but no systemPromptTemplate was specified. System messages will be ignored."
);
}
for (const message of messages) {
if (message.role === "user") {
const userMessage = promptTemplate.replace(
"%1",
message["content"]
if (dimensionality === undefined) {
dimensionality = -1;
} else {
if (dimensionality <= 0) {
throw new Error(
`Dimensionality must be undefined or a positive integer, got ${dimensionality}`
);
fullPrompt += userMessage;
}
if (message["role"] == "assistant") {
const assistantMessage = message["content"] + "\n";
fullPrompt += assistantMessage;
if (dimensionality < model.MIN_DIMENSIONALITY) {
console.warn(
`Dimensionality ${dimensionality} is less than the suggested minimum of ${model.MIN_DIMENSIONALITY}. Performance may be degraded.`
);
}
}
if (promptFooter) {
fullPrompt += "\n\n" + promptFooter;
let doMean;
switch (longTextMode) {
case "mean":
doMean = true;
break;
case "truncate":
doMean = false;
break;
default:
throw new Error(
`Long text mode must be one of 'mean' or 'truncate', got ${longTextMode}`
);
}
return fullPrompt;
}
function createEmbedding(model, text) {
return model.embed(text);
return model.embed(text, options?.prefix, dimensionality, doMean, atlas);
}
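// Usage sketch for the options handled above (model name is a placeholder;
// a `prefix` can also be passed for prefix-aware embedding models and is omitted here).
async function exampleEmbeddingUsage() {
    const embedder = await loadModel("all-MiniLM-L6-v2-f16.gguf", { type: "embedding" });
    const vector = createEmbedding(embedder, "The quick brown fox", {
        longTextMode: "mean", // or "truncate"
        atlas: false,
    });
    console.log(vector);
    embedder.dispose();
}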
const defaultCompletionOptions = {
@ -155,162 +128,76 @@ const defaultCompletionOptions = {
...DEFAULT_PROMPT_CONTEXT,
};
function preparePromptAndContext(model,messages,options){
if (options.hasDefaultHeader !== undefined) {
console.warn(
"hasDefaultHeader (bool) is deprecated and has no effect, use promptHeader (string) instead"
);
}
if (options.hasDefaultFooter !== undefined) {
console.warn(
"hasDefaultFooter (bool) is deprecated and has no effect, use promptFooter (string) instead"
);
}
const optionsWithDefaults = {
...defaultCompletionOptions,
...options,
};
const {
verbose,
systemPromptTemplate,
promptTemplate,
promptHeader,
promptFooter,
...promptContext
} = optionsWithDefaults;
const prompt = formatChatPrompt(messages, {
systemPromptTemplate,
defaultSystemPrompt: model.config.systemPrompt,
promptTemplate: promptTemplate || model.config.promptTemplate || "%1",
promptHeader: promptHeader || "",
promptFooter: promptFooter || "",
// These were the default header/footer prompts used for non-chat single turn completions.
// both seem to be working well still with some models, so keeping them here for reference.
// promptHeader: '### Instruction: The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.',
// promptFooter: '### Response:',
});
return {
prompt, promptContext, verbose
}
}
async function createCompletion(
model,
messages,
provider,
input,
options = defaultCompletionOptions
) {
const { prompt, promptContext, verbose } = preparePromptAndContext(model,messages,options);
if (verbose) {
console.debug("Sending Prompt:\n" + prompt);
}
let tokensGenerated = 0
const response = await model.generate(prompt, promptContext,() => {
tokensGenerated++;
return true;
});
const completionOptions = {
...defaultCompletionOptions,
...options,
};
if (verbose) {
console.debug("Received Response:\n" + response);
}
const result = await provider.generate(
input,
completionOptions,
);
return {
llmodel: model.llm.name(),
model: provider.modelName,
usage: {
prompt_tokens: prompt.length,
completion_tokens: tokensGenerated,
total_tokens: prompt.length + tokensGenerated, //TODO Not sure how to get tokens in prompt
prompt_tokens: result.tokensIngested,
total_tokens: result.tokensIngested + result.tokensGenerated,
completion_tokens: result.tokensGenerated,
n_past_tokens: result.nPast,
},
choices: [
{
message: {
role: "assistant",
content: response,
content: result.text,
},
// TODO some completion APIs also provide logprobs and finish_reason, could look into adding those
},
],
};
}
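// Consumer-side sketch of the OpenAI-style object built above (model name is a placeholder):
async function exampleCompletionUsage() {
    const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf");
    const completion = await createCompletion(model, "What is 1 + 1?");
    console.log(completion.choices[0].message.content);
    console.log(completion.usage); // prompt_tokens, completion_tokens, total_tokens, n_past_tokens
    model.dispose();
}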
function _internal_createTokenStream(stream,model,
messages,
options = defaultCompletionOptions,callback = undefined) {
const { prompt, promptContext, verbose } = preparePromptAndContext(model,messages,options);
if (verbose) {
console.debug("Sending Prompt:\n" + prompt);
}
model.generate(prompt, promptContext,(tokenId, token, total) => {
stream.push(token);
if(callback !== undefined){
return callback(tokenId,token,total);
}
return true;
}).then(() => {
stream.end()
})
return stream;
}
function _createTokenStream(model,
messages,
options = defaultCompletionOptions,callback = undefined) {
// Silent crash if we dont do this here
const stream = new Stream.PassThrough({
encoding: 'utf-8'
});
return _internal_createTokenStream(stream,model,messages,options,callback);
}
async function* generateTokens(model,
messages,
options = defaultCompletionOptions, callback = undefined) {
const stream = _createTokenStream(model,messages,options,callback)
let bHasFinished = false;
let activeDataCallback = undefined;
const finishCallback = () => {
bHasFinished = true;
if(activeDataCallback !== undefined){
activeDataCallback(undefined);
}
}
stream.on("finish",finishCallback)
while (!bHasFinished) {
const token = await new Promise((res) => {
function createCompletionStream(
provider,
input,
options = defaultCompletionOptions
) {
const completionStream = new Stream.PassThrough({
encoding: "utf-8",
});
activeDataCallback = (d) => {
stream.off("data",activeDataCallback)
activeDataCallback = undefined
res(d);
const completionPromise = createCompletion(provider, input, {
...options,
onResponseToken: (tokenId, token) => {
completionStream.push(token);
if (options.onResponseToken) {
return options.onResponseToken(tokenId, token);
}
stream.on('data', activeDataCallback)
})
},
}).then((result) => {
completionStream.push(null);
completionStream.emit("end");
return result;
});
if (token == undefined) {
break;
}
return {
tokens: completionStream,
result: completionPromise,
};
}
yield token;
async function* createCompletionGenerator(provider, input, options) {
const completion = createCompletionStream(provider, input, options);
for await (const chunk of completion.tokens) {
yield chunk;
}
stream.off("finish",finishCallback);
return await completion.result;
}
module.exports = {
@ -322,10 +209,12 @@ module.exports = {
LLModel,
InferenceModel,
EmbeddingModel,
ChatSession,
createCompletion,
createCompletionStream,
createCompletionGenerator,
createEmbedding,
downloadModel,
retrieveModel,
loadModel,
generateTokens
};

@ -1,18 +1,138 @@
const { normalizePromptContext, warnOnSnakeCaseKeys } = require('./util');
const { DEFAULT_PROMPT_CONTEXT } = require("./config");
const { ChatSession } = require("./chat-session");
const { prepareMessagesForIngest } = require("./util");
class InferenceModel {
llm;
modelName;
config;
activeChatSession;
constructor(llmodel, config) {
this.llm = llmodel;
this.config = config;
this.modelName = this.llm.name();
}
async generate(prompt, promptContext,callback) {
warnOnSnakeCaseKeys(promptContext);
const normalizedPromptContext = normalizePromptContext(promptContext);
const result = this.llm.raw_prompt(prompt, normalizedPromptContext,callback);
async createChatSession(options) {
const chatSession = new ChatSession(this, options);
await chatSession.initialize();
this.activeChatSession = chatSession;
return this.activeChatSession;
}
async generate(input, options = DEFAULT_PROMPT_CONTEXT) {
const { verbose, ...otherOptions } = options;
const promptContext = {
promptTemplate: this.config.promptTemplate,
temp:
otherOptions.temp ??
otherOptions.temperature ??
DEFAULT_PROMPT_CONTEXT.temp,
...otherOptions,
};
if (promptContext.nPast < 0) {
throw new Error("nPast must be a non-negative integer.");
}
if (verbose) {
console.debug("Generating completion", {
input,
promptContext,
});
}
let prompt = input;
let nPast = promptContext.nPast;
let tokensIngested = 0;
if (Array.isArray(input)) {
// assuming input is a messages array
// -> trailing user message will be used as the final prompt; it's required.
// -> leading system message will be ingested as systemPrompt, further system messages will be ignored
// -> all other messages will be ingested with fakeReply
// -> model/context will only be kept for this completion; "stateless"
nPast = 0;
const messages = [...input];
const lastMessage = input[input.length - 1];
if (lastMessage.role !== "user") {
// this is most likely a user error
throw new Error("The final message must be of role 'user'.");
}
if (input[0].role === "system") {
// needs to be a pre-templated prompt, e.g. '<|im_start|>system\nYou are an advanced mathematician.\n<|im_end|>\n'
const systemPrompt = input[0].content;
const systemRes = await this.llm.infer(systemPrompt, {
promptTemplate: "%1",
nPredict: 0,
special: true,
});
nPast = systemRes.nPast;
tokensIngested += systemRes.tokensIngested;
messages.shift();
}
prompt = lastMessage.content;
const messagesToIngest = messages.slice(0, input.length - 1);
const turns = prepareMessagesForIngest(messagesToIngest);
for (const turn of turns) {
const turnRes = await this.llm.infer(turn.user, {
...promptContext,
nPast,
fakeReply: turn.assistant,
});
tokensIngested += turnRes.tokensIngested;
nPast = turnRes.nPast;
}
}
let tokensGenerated = 0;
const result = await this.llm.infer(prompt, {
...promptContext,
nPast,
onPromptToken: (tokenId) => {
let continueIngestion = true;
tokensIngested++;
if (options.onPromptToken) {
// catch errors because if they go through cpp they will lose their stack traces
try {
// don't cancel ingestion unless user explicitly returns false
continueIngestion =
options.onPromptToken(tokenId) !== false;
} catch (e) {
console.error("Error in onPromptToken callback", e);
continueIngestion = false;
}
}
return continueIngestion;
},
onResponseToken: (tokenId, token) => {
let continueGeneration = true;
tokensGenerated++;
if (options.onResponseToken) {
try {
// don't cancel the generation unless user explicitly returns false
continueGeneration =
options.onResponseToken(tokenId, token) !== false;
} catch (err) {
console.error("Error in onResponseToken callback", err);
continueGeneration = false;
}
}
return continueGeneration;
},
});
result.tokensGenerated = tokensGenerated;
result.tokensIngested = tokensIngested;
if (verbose) {
console.debug("Finished completion:\n", result);
}
return result;
}
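// Sketch: calling generate directly with a messages array. Per the ingestion
// logic above this is stateless; a leading system message is expected to
// already contain any chat-template tokens the model needs, so it is omitted here.
async function exampleStatelessGenerate(inferenceModel) {
    const result = await inferenceModel.generate([
        { role: "user", content: "What is 1 + 1?" },
        { role: "assistant", content: "2" },
        { role: "user", content: "And if we add two?" },
    ]);
    console.log(result.text, result.nPast, result.tokensIngested, result.tokensGenerated);
}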
@ -24,14 +144,14 @@ class InferenceModel {
class EmbeddingModel {
llm;
config;
MIN_DIMENSIONALITY = 64;
constructor(llmodel, config) {
this.llm = llmodel;
this.config = config;
}
embed(text) {
return this.llm.embed(text)
embed(text, prefix, dimensionality, do_mean, atlas) {
return this.llm.embed(text, prefix, dimensionality, do_mean, atlas);
}
dispose() {
@ -39,7 +159,6 @@ class EmbeddingModel {
}
}
module.exports = {
InferenceModel,
EmbeddingModel,

@ -1,8 +1,7 @@
const { createWriteStream, existsSync, statSync } = require("node:fs");
const { createWriteStream, existsSync, statSync, mkdirSync } = require("node:fs");
const fsp = require("node:fs/promises");
const { performance } = require("node:perf_hooks");
const path = require("node:path");
const { mkdirp } = require("mkdirp");
const md5File = require("md5-file");
const {
DEFAULT_DIRECTORY,
@ -50,6 +49,63 @@ function appendBinSuffixIfMissing(name) {
return name;
}
function prepareMessagesForIngest(messages) {
const systemMessages = messages.filter(
(message) => message.role === "system"
);
if (systemMessages.length > 0) {
console.warn(
"System messages are currently not supported and will be ignored. Use the systemPrompt option instead."
);
}
const userAssistantMessages = messages.filter(
(message) => message.role !== "system"
);
// make sure the first message is a user message
// if it's not, the turns will be out of order
if (userAssistantMessages[0].role !== "user") {
userAssistantMessages.unshift({
role: "user",
content: "",
});
}
// create turns of user input + assistant reply
const turns = [];
let userMessage = null;
let assistantMessage = null;
for (const message of userAssistantMessages) {
// consecutive messages of the same role are concatenated into one message
if (message.role === "user") {
if (!userMessage) {
userMessage = message.content;
} else {
userMessage += "\n" + message.content;
}
} else if (message.role === "assistant") {
if (!assistantMessage) {
assistantMessage = message.content;
} else {
assistantMessage += "\n" + message.content;
}
}
if (userMessage && assistantMessage) {
turns.push({
user: userMessage,
assistant: assistantMessage,
});
userMessage = null;
assistantMessage = null;
}
}
return turns;
}
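// Illustration: consecutive same-role messages are merged and paired into
// user/assistant turns (system messages are dropped with a warning).
const exampleTurns = prepareMessagesForIngest([
    { role: "user", content: "Hi" },
    { role: "assistant", content: "Hello!" },
    { role: "user", content: "How" },
    { role: "user", content: "are you?" },
    { role: "assistant", content: "Great." },
]);
// exampleTurns === [
//   { user: "Hi", assistant: "Hello!" },
//   { user: "How\nare you?", assistant: "Great." },
// ]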
// readChunks() reads from the provided reader and yields the results into an async iterable
// https://css-tricks.com/web-streams-everywhere-and-fetch-for-node-js/
function readChunks(reader) {
@ -64,49 +120,13 @@ function readChunks(reader) {
};
}
/**
* Prints a warning if any keys in the prompt context are snake_case.
*/
function warnOnSnakeCaseKeys(promptContext) {
const snakeCaseKeys = Object.keys(promptContext).filter((key) =>
key.includes("_")
);
if (snakeCaseKeys.length > 0) {
console.warn(
"Prompt context keys should be camelCase. Support for snake_case might be removed in the future. Found keys: " +
snakeCaseKeys.join(", ")
);
}
}
/**
* Converts all keys in the prompt context to snake_case
* For duplicate definitions, the value of the last occurrence will be used.
*/
function normalizePromptContext(promptContext) {
const normalizedPromptContext = {};
for (const key in promptContext) {
if (promptContext.hasOwnProperty(key)) {
const snakeKey = key.replace(
/[A-Z]/g,
(match) => `_${match.toLowerCase()}`
);
normalizedPromptContext[snakeKey] = promptContext[key];
}
}
return normalizedPromptContext;
}
function downloadModel(modelName, options = {}) {
const downloadOptions = {
modelPath: DEFAULT_DIRECTORY,
verbose: false,
...options,
};
const modelFileName = appendBinSuffixIfMissing(modelName);
const partialModelPath = path.join(
downloadOptions.modelPath,
@ -114,16 +134,17 @@ function downloadModel(modelName, options = {}) {
);
const finalModelPath = path.join(downloadOptions.modelPath, modelFileName);
const modelUrl =
downloadOptions.url ?? `https://gpt4all.io/models/gguf/${modelFileName}`;
downloadOptions.url ??
`https://gpt4all.io/models/gguf/${modelFileName}`;
mkdirp.sync(downloadOptions.modelPath)
mkdirSync(downloadOptions.modelPath, { recursive: true });
if (existsSync(finalModelPath)) {
throw Error(`Model already exists at ${finalModelPath}`);
}
if (downloadOptions.verbose) {
console.log(`Downloading ${modelName} from ${modelUrl}`);
console.debug(`Downloading ${modelName} from ${modelUrl}`);
}
const headers = {
@ -134,7 +155,9 @@ function downloadModel(modelName, options = {}) {
const writeStreamOpts = {};
if (existsSync(partialModelPath)) {
console.log("Partial model exists, resuming download...");
if (downloadOptions.verbose) {
console.debug("Partial model exists, resuming download...");
}
const startRange = statSync(partialModelPath).size;
headers["Range"] = `bytes=${startRange}-`;
writeStreamOpts.flags = "a";
@ -144,15 +167,15 @@ function downloadModel(modelName, options = {}) {
const signal = abortController.signal;
const finalizeDownload = async () => {
if (options.md5sum) {
if (downloadOptions.md5sum) {
const fileHash = await md5File(partialModelPath);
if (fileHash !== options.md5sum) {
if (fileHash !== downloadOptions.md5sum) {
await fsp.unlink(partialModelPath);
const message = `Model "${modelName}" failed verification: Hashes mismatch. Expected ${options.md5sum}, got ${fileHash}`;
const message = `Model "${modelName}" failed verification: Hashes mismatch. Expected ${downloadOptions.md5sum}, got ${fileHash}`;
throw Error(message);
}
if (options.verbose) {
console.log(`MD5 hash verified: ${fileHash}`);
if (downloadOptions.verbose) {
console.debug(`MD5 hash verified: ${fileHash}`);
}
}
@ -163,8 +186,8 @@ function downloadModel(modelName, options = {}) {
const downloadPromise = new Promise((resolve, reject) => {
let timestampStart;
if (options.verbose) {
console.log(`Downloading @ ${partialModelPath} ...`);
if (downloadOptions.verbose) {
console.debug(`Downloading @ ${partialModelPath} ...`);
timestampStart = performance.now();
}
@ -179,7 +202,7 @@ function downloadModel(modelName, options = {}) {
});
writeStream.on("finish", () => {
if (options.verbose) {
if (downloadOptions.verbose) {
const elapsed = performance.now() - timestampStart;
console.log(`Finished. Download took ${elapsed.toFixed(2)} ms`);
}
@ -221,10 +244,10 @@ async function retrieveModel(modelName, options = {}) {
const retrieveOptions = {
modelPath: DEFAULT_DIRECTORY,
allowDownload: true,
verbose: true,
verbose: false,
...options,
};
await mkdirp(retrieveOptions.modelPath);
mkdirSync(retrieveOptions.modelPath, { recursive: true });
const modelFileName = appendBinSuffixIfMissing(modelName);
const fullModelPath = path.join(retrieveOptions.modelPath, modelFileName);
@ -236,7 +259,7 @@ async function retrieveModel(modelName, options = {}) {
file: retrieveOptions.modelConfigFile,
url:
retrieveOptions.allowDownload &&
"https://gpt4all.io/models/models2.json",
"https://gpt4all.io/models/models3.json",
});
const loadedModelConfig = availableModels.find(
@ -262,10 +285,9 @@ async function retrieveModel(modelName, options = {}) {
config.path = fullModelPath;
if (retrieveOptions.verbose) {
console.log(`Found ${modelName} at ${fullModelPath}`);
console.debug(`Found ${modelName} at ${fullModelPath}`);
}
} else if (retrieveOptions.allowDownload) {
const downloadController = downloadModel(modelName, {
modelPath: retrieveOptions.modelPath,
verbose: retrieveOptions.verbose,
@ -278,7 +300,7 @@ async function retrieveModel(modelName, options = {}) {
config.path = downloadPath;
if (retrieveOptions.verbose) {
console.log(`Model downloaded to ${downloadPath}`);
console.debug(`Model downloaded to ${downloadPath}`);
}
} else {
throw Error("Failed to retrieve model.");
@ -288,9 +310,8 @@ async function retrieveModel(modelName, options = {}) {
module.exports = {
appendBinSuffixIfMissing,
prepareMessagesForIngest,
downloadModel,
retrieveModel,
listModels,
normalizePromptContext,
warnOnSnakeCaseKeys,
};

@ -7,7 +7,6 @@ const {
listModels,
downloadModel,
appendBinSuffixIfMissing,
normalizePromptContext,
} = require("../src/util.js");
const {
DEFAULT_DIRECTORY,
@ -19,8 +18,6 @@ const {
createPrompt,
createCompletion,
} = require("../src/gpt4all.js");
const { mock } = require("node:test");
const { mkdirp } = require("mkdirp");
describe("config", () => {
test("default paths constants are available and correct", () => {
@ -87,7 +84,7 @@ describe("listModels", () => {
expect(fetch).toHaveBeenCalledTimes(0);
expect(models[0]).toEqual(fakeModel);
});
it("should throw an error if neither url nor file is specified", async () => {
await expect(listModels(null)).rejects.toThrow(
"No model list source specified. Please specify either a url or a file."
@ -141,10 +138,10 @@ describe("downloadModel", () => {
mockAbortController.mockReset();
mockFetch.mockClear();
global.fetch.mockRestore();
const rootDefaultPath = path.resolve(DEFAULT_DIRECTORY),
partialPath = path.resolve(rootDefaultPath, fakeModelName+'.part'),
fullPath = path.resolve(rootDefaultPath, fakeModelName+'.bin')
fullPath = path.resolve(rootDefaultPath, fakeModelName+'.bin')
// if tests fail, remove the created files
// acts as cleanup if tests fail
@ -206,46 +203,3 @@ describe("downloadModel", () => {
// test("should be able to cancel and resume a download", async () => {
// });
});
describe("normalizePromptContext", () => {
it("should convert a dict with camelCased keys to snake_case", () => {
const camelCased = {
topK: 20,
repeatLastN: 10,
};
const expectedSnakeCased = {
top_k: 20,
repeat_last_n: 10,
};
const result = normalizePromptContext(camelCased);
expect(result).toEqual(expectedSnakeCased);
});
it("should convert a mixed case dict to snake_case, last value taking precedence", () => {
const mixedCased = {
topK: 20,
top_k: 10,
repeatLastN: 10,
};
const expectedSnakeCased = {
top_k: 10,
repeat_last_n: 10,
};
const result = normalizePromptContext(mixedCased);
expect(result).toEqual(expectedSnakeCased);
});
it("should not modify already snake cased dict", () => {
const snakeCased = {
top_k: 10,
repeast_last_n: 10,
};
const result = normalizePromptContext(snakeCased);
expect(result).toEqual(snakeCased);
});
});

@ -2300,7 +2300,6 @@ __metadata:
documentation: ^14.0.2
jest: ^29.5.0
md5-file: ^5.0.0
mkdirp: ^3.0.1
node-addon-api: ^6.1.0
node-gyp: 9.x.x
node-gyp-build: ^4.6.0
@ -4258,15 +4257,6 @@ __metadata:
languageName: node
linkType: hard
"mkdirp@npm:^3.0.1":
version: 3.0.1
resolution: "mkdirp@npm:3.0.1"
bin:
mkdirp: dist/cjs/src/bin.js
checksum: 972deb188e8fb55547f1e58d66bd6b4a3623bf0c7137802582602d73e6480c1c2268dcbafbfb1be466e00cc7e56ac514d7fd9334b7cf33e3e2ab547c16f83a8d
languageName: node
linkType: hard
"mri@npm:^1.1.0":
version: 1.2.0
resolution: "mri@npm:1.2.0"
