Skip to content

Commit b10d96f

Browse files
committed
Refactor and make use of jinja templates.
Signed-off-by: Adam Treat <[email protected]>
1 parent 0d56401 commit b10d96f

14 files changed

+245
-121
lines changed

gpt4all-chat/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ else()
310310
PRIVATE Qt6::Quick Qt6::Svg Qt6::HttpServer Qt6::Sql Qt6::Pdf)
311311
endif()
312312
target_link_libraries(chat
313-
PRIVATE llmodel)
313+
PRIVATE llmodel jinja2cpp)
314314

315315

316316
# -- install --

gpt4all-chat/bravesearch.cpp

+50
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,56 @@ QString BraveSearch::run(const QJsonObject &parameters, qint64 timeout)
4242
return worker.response();
4343
}
4444

45+
QJsonObject BraveSearch::paramSchema() const
46+
{
47+
static const QString braveParamSchema = R"({
48+
"apiKey": {
49+
"type": "string",
50+
"description": "The api key to use",
51+
"required": true,
52+
"modelGenerated": false,
53+
"userConfigured": true
54+
},
55+
"query": {
56+
"type": "string",
57+
"description": "The query to search",
58+
"required": true
59+
},
60+
"count": {
61+
"type": "integer",
62+
"description": "The number of excerpts to return",
63+
"required": true,
64+
"modelGenerated": false
65+
}
66+
})";
67+
68+
static const QJsonDocument braveJsonDoc = QJsonDocument::fromJson(braveParamSchema.toUtf8());
69+
Q_ASSERT(!braveJsonDoc.isNull() && braveJsonDoc.isObject());
70+
return braveJsonDoc.object();
71+
}
72+
73+
QJsonObject BraveSearch::exampleParams() const
74+
{
75+
static const QString example = R"({
76+
"query": "the 44th president of the United States"
77+
})";
78+
static const QJsonDocument exampleDoc = QJsonDocument::fromJson(example.toUtf8());
79+
Q_ASSERT(!exampleDoc.isNull() && exampleDoc.isObject());
80+
return exampleDoc.object();
81+
}
82+
83+
bool BraveSearch::isEnabled() const
84+
{
85+
// FIXME: Refer to mysettings
86+
return true;
87+
}
88+
89+
bool BraveSearch::forceUsage() const
90+
{
91+
// FIXME: Refer to mysettings
92+
return false;
93+
}
94+
4595
void BraveAPIWorker::request(const QString &apiKey, const QString &query, int count)
4696
{
4797
// Documentation on the brave web search:

gpt4all-chat/bravesearch.h

+10
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ class BraveSearch : public Tool {
4848
ToolEnums::Error error() const override { return m_error; }
4949
QString errorString() const override { return m_errorString; }
5050

51+
QString name() const override { return tr("Brave web search"); }
52+
QString description() const override { return tr("Search the web using brave"); }
53+
QString function() const override { return "brave_search"; }
54+
QJsonObject paramSchema() const override;
55+
QJsonObject exampleParams() const override;
56+
bool isEnabled() const override;
57+
bool isBuiltin() const override { return true; }
58+
bool forceUsage() const override;
59+
bool excerpts() const override { return true; }
60+
5161
private:
5262
ToolEnums::Error m_error;
5363
QString m_errorString;

gpt4all-chat/chatllm.cpp

+32-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#include "localdocssearch.h"
77
#include "mysettings.h"
88
#include "network.h"
9+
#include "tool.h"
10+
#include "toolmodel.h"
911

1012
#include <QDataStream>
1113
#include <QDebug>
@@ -29,6 +31,7 @@
2931
#include <cmath>
3032
#include <cstddef>
3133
#include <functional>
34+
#include <jinja2cpp/template.h>
3235
#include <limits>
3336
#include <optional>
3437
#include <string_view>
@@ -1332,7 +1335,35 @@ void ChatLLM::processSystemPrompt()
13321335
if (!isModelLoaded() || m_processedSystemPrompt || m_restoreStateFromText || m_isServer)
13331336
return;
13341337

1335-
const std::string systemPrompt = MySettings::globalInstance()->modelSystemPrompt(m_modelInfo).toStdString();
1338+
const std::string systemPromptTemplate = MySettings::globalInstance()->modelSystemPromptTemplate(m_modelInfo).toStdString();
1339+
1340+
// FIXME: This needs to be moved to settings probably and the same code used for validation
1341+
jinja2::ValuesMap params;
1342+
params.insert({"currentDate", QDate::currentDate().toString().toStdString()});
1343+
1344+
jinja2::ValuesList toolList;
1345+
int c = ToolModel::globalInstance()->count();
1346+
for (int i = 0; i < c; ++i) {
1347+
Tool *t = ToolModel::globalInstance()->get(i);
1348+
if (t->isEnabled() && !t->forceUsage())
1349+
toolList.push_back(t->jinjaValue());
1350+
}
1351+
params.insert({"toolList", toolList});
1352+
1353+
std::string systemPrompt;
1354+
1355+
jinja2::Template t;
1356+
t.Load(systemPromptTemplate);
1357+
const auto renderResult = t.RenderAsString(params);
1358+
1359+
// The GUI should not allow setting an improper template, but it is always possible someone hand
1360+
// edits the settings file to produce an improper one.
1361+
Q_ASSERT(renderResult);
1362+
if (renderResult)
1363+
systemPrompt = renderResult.value();
1364+
else
1365+
qWarning() << "ERROR: Could not parse system prompt template:" << renderResult.error().ToString();
1366+
13361367
if (QString::fromStdString(systemPrompt).trimmed().isEmpty()) {
13371368
m_processedSystemPrompt = true;
13381369
return;

gpt4all-chat/localdocssearch.cpp

+32
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <QDebug>
77
#include <QGuiApplication>
88
#include <QJsonArray>
9+
#include <QJsonDocument>
910
#include <QJsonObject>
1011
#include <QThread>
1112

@@ -39,6 +40,37 @@ QString LocalDocsSearch::run(const QJsonObject &parameters, qint64 timeout)
3940
return worker.response();
4041
}
4142

43+
QJsonObject LocalDocsSearch::paramSchema() const
44+
{
45+
static const QString localParamSchema = R"({
46+
"collections": {
47+
"type": "array",
48+
"items": {
49+
"type": "string"
50+
},
51+
"description": "The collections to search",
52+
"required": true,
53+
"modelGenerated": false,
54+
"userConfigured": false
55+
},
56+
"query": {
57+
"type": "string",
58+
"description": "The query to search",
59+
"required": true
60+
},
61+
"count": {
62+
"type": "integer",
63+
"description": "The number of excerpts to return",
64+
"required": true,
65+
"modelGenerated": false
66+
}
67+
})";
68+
69+
static const QJsonDocument localJsonDoc = QJsonDocument::fromJson(localParamSchema.toUtf8());
70+
Q_ASSERT(!localJsonDoc.isNull() && localJsonDoc.isObject());
71+
return localJsonDoc.object();
72+
}
73+
4274
LocalDocsWorker::LocalDocsWorker()
4375
: QObject(nullptr)
4476
{

gpt4all-chat/localdocssearch.h

+9
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ class LocalDocsSearch : public Tool {
3434
ToolEnums::Error error() const override { return m_error; }
3535
QString errorString() const override { return m_errorString; }
3636

37+
QString name() const override { return tr("LocalDocs search"); }
38+
QString description() const override { return tr("Search the local docs"); }
39+
QString function() const override { return "localdocs_search"; }
40+
QJsonObject paramSchema() const override;
41+
bool isEnabled() const override { return true; }
42+
bool isBuiltin() const override { return true; }
43+
bool forceUsage() const override { return true; }
44+
bool excerpts() const override { return true; }
45+
3746
private:
3847
ToolEnums::Error m_error;
3948
QString m_errorString;

gpt4all-chat/modellist.cpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -334,15 +334,15 @@ void ModelInfo::setToolTemplate(const QString &t)
334334
m_toolTemplate = t;
335335
}
336336

337-
QString ModelInfo::systemPrompt() const
337+
QString ModelInfo::systemPromptTemplate() const
338338
{
339-
return MySettings::globalInstance()->modelSystemPrompt(*this);
339+
return MySettings::globalInstance()->modelSystemPromptTemplate(*this);
340340
}
341341

342-
void ModelInfo::setSystemPrompt(const QString &p)
342+
void ModelInfo::setSystemPromptTemplate(const QString &p)
343343
{
344-
if (shouldSaveMetadata()) MySettings::globalInstance()->setModelSystemPrompt(*this, p, true /*force*/);
345-
m_systemPrompt = p;
344+
if (shouldSaveMetadata()) MySettings::globalInstance()->setModelSystemPromptTemplate(*this, p, true /*force*/);
345+
m_systemPromptTemplate = p;
346346
}
347347

348348
QString ModelInfo::chatNamePrompt() const
@@ -397,7 +397,7 @@ QVariantMap ModelInfo::getFields() const
397397
{ "repeatPenaltyTokens", m_repeatPenaltyTokens },
398398
{ "promptTemplate", m_promptTemplate },
399399
{ "toolTemplate", m_toolTemplate },
400-
{ "systemPrompt", m_systemPrompt },
400+
{ "systemPromptTemplate",m_systemPromptTemplate },
401401
{ "chatNamePrompt", m_chatNamePrompt },
402402
{ "suggestedFollowUpPrompt", m_suggestedFollowUpPrompt },
403403
};
@@ -792,7 +792,7 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
792792
case ToolTemplateRole:
793793
return info->toolTemplate();
794794
case SystemPromptRole:
795-
return info->systemPrompt();
795+
return info->systemPromptTemplate();
796796
case ChatNamePromptRole:
797797
return info->chatNamePrompt();
798798
case SuggestedFollowUpPromptRole:
@@ -970,7 +970,7 @@ void ModelList::updateData(const QString &id, const QVector<QPair<int, QVariant>
970970
case ToolTemplateRole:
971971
info->setToolTemplate(value.toString()); break;
972972
case SystemPromptRole:
973-
info->setSystemPrompt(value.toString()); break;
973+
info->setSystemPromptTemplate(value.toString()); break;
974974
case ChatNamePromptRole:
975975
info->setChatNamePrompt(value.toString()); break;
976976
case SuggestedFollowUpPromptRole:
@@ -1125,7 +1125,7 @@ QString ModelList::clone(const ModelInfo &model)
11251125
{ ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens() },
11261126
{ ModelList::PromptTemplateRole, model.promptTemplate() },
11271127
{ ModelList::ToolTemplateRole, model.toolTemplate() },
1128-
{ ModelList::SystemPromptRole, model.systemPrompt() },
1128+
{ ModelList::SystemPromptRole, model.systemPromptTemplate() },
11291129
{ ModelList::ChatNamePromptRole, model.chatNamePrompt() },
11301130
{ ModelList::SuggestedFollowUpPromptRole, model.suggestedFollowUpPrompt() },
11311131
};

gpt4all-chat/modellist.h

+5-4
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ struct ModelInfo {
6969
Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens)
7070
Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate)
7171
Q_PROPERTY(QString toolTemplate READ toolTemplate WRITE setToolTemplate)
72-
Q_PROPERTY(QString systemPrompt READ systemPrompt WRITE setSystemPrompt)
72+
Q_PROPERTY(QString systemPromptTemplate READ systemPromptTemplate WRITE setSystemPromptTemplate)
7373
Q_PROPERTY(QString chatNamePrompt READ chatNamePrompt WRITE setChatNamePrompt)
7474
Q_PROPERTY(QString suggestedFollowUpPrompt READ suggestedFollowUpPrompt WRITE setSuggestedFollowUpPrompt)
7575
Q_PROPERTY(int likes READ likes WRITE setLikes)
@@ -181,8 +181,9 @@ struct ModelInfo {
181181
void setPromptTemplate(const QString &t);
182182
QString toolTemplate() const;
183183
void setToolTemplate(const QString &t);
184-
QString systemPrompt() const;
185-
void setSystemPrompt(const QString &p);
184+
QString systemPromptTemplate() const;
185+
void setSystemPromptTemplate(const QString &p);
186+
// FIXME (adam): The chatname and suggested follow-up should also be templates I guess?
186187
QString chatNamePrompt() const;
187188
void setChatNamePrompt(const QString &p);
188189
QString suggestedFollowUpPrompt() const;
@@ -219,7 +220,7 @@ struct ModelInfo {
219220
int m_repeatPenaltyTokens = 64;
220221
QString m_promptTemplate = "### Human:\n%1\n\n### Assistant:\n";
221222
QString m_toolTemplate = "";
222-
QString m_systemPrompt = "### System:\nYou are an AI assistant who gives a quality response to whatever humans ask of you.\n\n";
223+
QString m_systemPromptTemplate = "### System:\nYou are an AI assistant who gives a quality response to whatever humans ask of you.\n\n";
223224
QString m_chatNamePrompt = "Describe the above conversation in seven words or less.";
224225
QString m_suggestedFollowUpPrompt = "Suggest three very short factual follow-up questions that have not been answered yet or cannot be found inspired by the previous conversation and excerpts.";
225226
friend class MySettings;

gpt4all-chat/mysettings.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &info)
194194
setModelRepeatPenaltyTokens(info, info.m_repeatPenaltyTokens);
195195
setModelPromptTemplate(info, info.m_promptTemplate);
196196
setModelToolTemplate(info, info.m_toolTemplate);
197-
setModelSystemPrompt(info, info.m_systemPrompt);
197+
setModelSystemPromptTemplate(info, info.m_systemPromptTemplate);
198198
setModelChatNamePrompt(info, info.m_chatNamePrompt);
199199
setModelSuggestedFollowUpPrompt(info, info.m_suggestedFollowUpPrompt);
200200
}
@@ -297,7 +297,7 @@ double MySettings::modelRepeatPenalty (const ModelInfo &info) const
297297
int MySettings::modelRepeatPenaltyTokens (const ModelInfo &info) const { return getModelSetting("repeatPenaltyTokens", info).toInt(); }
298298
QString MySettings::modelPromptTemplate (const ModelInfo &info) const { return getModelSetting("promptTemplate", info).toString(); }
299299
QString MySettings::modelToolTemplate (const ModelInfo &info) const { return getModelSetting("toolTemplate", info).toString(); }
300-
QString MySettings::modelSystemPrompt (const ModelInfo &info) const { return getModelSetting("systemPrompt", info).toString(); }
300+
QString MySettings::modelSystemPromptTemplate (const ModelInfo &info) const { return getModelSetting("systemPrompt", info).toString(); }
301301
QString MySettings::modelChatNamePrompt (const ModelInfo &info) const { return getModelSetting("chatNamePrompt", info).toString(); }
302302
QString MySettings::modelSuggestedFollowUpPrompt(const ModelInfo &info) const { return getModelSetting("suggestedFollowUpPrompt", info).toString(); }
303303

@@ -411,7 +411,7 @@ void MySettings::setModelToolTemplate(const ModelInfo &info, const QString &valu
411411
setModelSetting("toolTemplate", info, value, force, true);
412412
}
413413

414-
void MySettings::setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force)
414+
void MySettings::setModelSystemPromptTemplate(const ModelInfo &info, const QString &value, bool force)
415415
{
416416
setModelSetting("systemPrompt", info, value, force, true);
417417
}

gpt4all-chat/mysettings.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ class MySettings : public QObject
128128
Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force = false);
129129
QString modelToolTemplate(const ModelInfo &info) const;
130130
Q_INVOKABLE void setModelToolTemplate(const ModelInfo &info, const QString &value, bool force = false);
131-
QString modelSystemPrompt(const ModelInfo &info) const;
132-
Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force = false);
131+
QString modelSystemPromptTemplate(const ModelInfo &info) const;
132+
Q_INVOKABLE void setModelSystemPromptTemplate(const ModelInfo &info, const QString &value, bool force = false);
133133
int modelContextLength(const ModelInfo &info) const;
134134
Q_INVOKABLE void setModelContextLength(const ModelInfo &info, int value, bool force = false);
135135
int modelGpuLayers(const ModelInfo &info) const;

gpt4all-chat/tool.cpp

+30
Original file line numberDiff line numberDiff line change
@@ -1 +1,31 @@
11
#include "tool.h"
2+
3+
#include <QJsonDocument>
4+
5+
QJsonObject filterModelGeneratedProperties(const QJsonObject &inputObject) {
6+
QJsonObject filteredObject;
7+
for (const QString &key : inputObject.keys()) {
8+
QJsonObject propertyObject = inputObject.value(key).toObject();
9+
if (!propertyObject.contains("modelGenerated") || propertyObject["modelGenerated"].toBool())
10+
filteredObject.insert(key, propertyObject);
11+
}
12+
return filteredObject;
13+
}
14+
15+
jinja2::Value Tool::jinjaValue() const
16+
{
17+
QJsonDocument doc(filterModelGeneratedProperties(paramSchema()));
18+
QString p(doc.toJson(QJsonDocument::Compact));
19+
20+
QJsonDocument exampleDoc(exampleParams());
21+
QString e(exampleDoc.toJson(QJsonDocument::Compact));
22+
23+
jinja2::ValuesMap params {
24+
{ "name", name().toStdString() },
25+
{ "description", description().toStdString() },
26+
{ "function", function().toStdString() },
27+
{ "paramSchema", p.toStdString() },
28+
{ "exampleParams", e.toStdString() }
29+
};
30+
return params;
31+
}

0 commit comments

Comments
 (0)