Skip to content

Commit 8de4954

Browse files
committed
Refactor the brave search and introduce an abstraction for tool calls.
Signed-off-by: Adam Treat <[email protected]>
1 parent 3f8ee0e commit 8de4954

10 files changed

+366
-140
lines changed

gpt4all-chat/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,10 @@ qt_add_executable(chat
123123
modellist.h modellist.cpp
124124
mysettings.h mysettings.cpp
125125
network.h network.cpp
126-
sourceexcerpt.h
126+
sourceexcerpt.h sourceexcerpt.cpp
127127
server.h server.cpp
128128
logger.h logger.cpp
129+
tool.h tool.cpp
129130
${APP_ICON_RESOURCE}
130131
${CHAT_EXE_RESOURCES}
131132
)

gpt4all-chat/bravesearch.cpp

+51-124
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,19 @@
1616

1717
using namespace Qt::Literals::StringLiterals;
1818

19-
QPair<QString, QList<SourceExcerpt>> BraveSearch::search(const QString &apiKey, const QString &query, int topK, unsigned long timeout)
19+
QString BraveSearch::run(const QJsonObject &parameters, qint64 timeout)
2020
{
21+
const QString apiKey = parameters["apiKey"].toString();
22+
const QString query = parameters["query"].toString();
23+
const int count = parameters["count"].toInt();
2124
QThread workerThread;
2225
BraveAPIWorker worker;
2326
worker.moveToThread(&workerThread);
2427
connect(&worker, &BraveAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
25-
connect(this, &BraveSearch::request, &worker, &BraveAPIWorker::request, Qt::QueuedConnection);
28+
connect(&workerThread, &QThread::started, [&worker, apiKey, query, count]() {
29+
worker.request(apiKey, query, count);
30+
});
2631
workerThread.start();
27-
emit request(apiKey, query, topK);
2832
workerThread.wait(timeout);
2933
workerThread.quit();
3034
workerThread.wait();
@@ -34,174 +38,97 @@ QPair<QString, QList<SourceExcerpt>> BraveSearch::search(const QString &apiKey,
3438
void BraveAPIWorker::request(const QString &apiKey, const QString &query, int topK)
3539
{
3640
m_topK = topK;
41+
42+
// Documentation on the brave web search:
43+
// https://api.search.brave.com/app/documentation/web-search/get-started
3744
QUrl jsonUrl("https://api.search.brave.com/res/v1/web/search");
45+
46+
// Documentation on the query options:
47+
//https://api.search.brave.com/app/documentation/web-search/query
3848
QUrlQuery urlQuery;
3949
urlQuery.addQueryItem("q", query);
50+
urlQuery.addQueryItem("count", QString::number(topK));
51+
urlQuery.addQueryItem("result_filter", "web");
52+
urlQuery.addQueryItem("extra_snippets", "true");
4053
jsonUrl.setQuery(urlQuery);
4154
QNetworkRequest request(jsonUrl);
4255
QSslConfiguration conf = request.sslConfiguration();
4356
conf.setPeerVerifyMode(QSslSocket::VerifyNone);
4457
request.setSslConfiguration(conf);
45-
4658
request.setRawHeader("X-Subscription-Token", apiKey.toUtf8());
47-
// request.setRawHeader("Accept-Encoding", "gzip");
4859
request.setRawHeader("Accept", "application/json");
49-
5060
m_networkManager = new QNetworkAccessManager(this);
5161
QNetworkReply *reply = m_networkManager->get(request);
5262
connect(qGuiApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort);
5363
connect(reply, &QNetworkReply::finished, this, &BraveAPIWorker::handleFinished);
5464
connect(reply, &QNetworkReply::errorOccurred, this, &BraveAPIWorker::handleErrorOccurred);
5565
}
5666

57-
static QPair<QString, QList<SourceExcerpt>> cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK = 1)
67+
static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK = 1)
5868
{
69+
// This parses the response from brave and formats it in json that conforms to the de facto
70+
// standard in SourceExcerpts::fromJson(...)
5971
QJsonParseError err;
6072
QJsonDocument document = QJsonDocument::fromJson(jsonResponse, &err);
6173
if (err.error != QJsonParseError::NoError) {
62-
qWarning() << "ERROR: Couldn't parse: " << jsonResponse << err.errorString();
63-
return QPair<QString, QList<SourceExcerpt>>();
74+
qWarning() << "ERROR: Couldn't parse brave response: " << jsonResponse << err.errorString();
75+
return QString();
6476
}
6577

78+
QString query;
6679
QJsonObject searchResponse = document.object();
6780
QJsonObject cleanResponse;
68-
QString query;
6981
QJsonArray cleanArray;
7082

71-
QList<SourceExcerpt> infos;
72-
7383
if (searchResponse.contains("query")) {
7484
QJsonObject queryObj = searchResponse["query"].toObject();
75-
if (queryObj.contains("original")) {
85+
if (queryObj.contains("original"))
7686
query = queryObj["original"].toString();
77-
}
7887
}
7988

8089
if (searchResponse.contains("mixed")) {
8190
QJsonObject mixedResults = searchResponse["mixed"].toObject();
8291
QJsonArray mainResults = mixedResults["main"].toArray();
92+
QJsonObject resultsObject = searchResponse["web"].toObject();
93+
QJsonArray resultsArray = resultsObject["results"].toArray();
8394

84-
for (int i = 0; i < std::min(mainResults.size(), topK); ++i) {
95+
for (int i = 0; i < std::min(mainResults.size(), resultsArray.size()); ++i) {
8596
QJsonObject m = mainResults[i].toObject();
8697
QString r_type = m["type"].toString();
87-
int idx = m["index"].toInt();
88-
QJsonObject resultsObject = searchResponse[r_type].toObject();
89-
QJsonArray resultsArray = resultsObject["results"].toArray();
90-
91-
QJsonValue cleaned;
92-
SourceExcerpt info;
93-
if (r_type == "web") {
94-
// For web data - add a single output from the search
95-
QJsonObject resultObj = resultsArray[idx].toObject();
96-
QStringList selectedKeys = {"type", "title", "url", "description", "date", "extra_snippets"};
97-
QJsonObject cleanedObj;
98-
for (const auto& key : selectedKeys) {
99-
if (resultObj.contains(key)) {
100-
cleanedObj.insert(key, resultObj[key]);
101-
}
102-
}
103-
104-
QStringList textKeys = {"description", "extra_snippets"};
105-
QJsonObject textObj;
106-
for (const auto& key : textKeys) {
107-
if (resultObj.contains(key)) {
108-
textObj.insert(key, resultObj[key]);
109-
}
98+
Q_ASSERT(r_type == "web");
99+
const int idx = m["index"].toInt();
100+
101+
QJsonObject resultObj = resultsArray[idx].toObject();
102+
QStringList selectedKeys = {"type", "title", "url", "description"};
103+
QJsonObject result;
104+
for (const auto& key : selectedKeys)
105+
if (resultObj.contains(key))
106+
result.insert(key, resultObj[key]);
107+
108+
if (resultObj.contains("page_age"))
109+
result.insert("date", resultObj["page_age"]);
110+
111+
QJsonArray excerpts;
112+
if (resultObj.contains("extra_snippets")) {
113+
QJsonArray snippets = resultObj["extra_snippets"].toArray();
114+
for (int i = 0; i < snippets.size(); ++i) {
115+
QString snippet = snippets[i].toString();
116+
QJsonObject excerpt;
117+
excerpt.insert("text", snippet);
118+
excerpts.append(excerpt);
110119
}
111-
112-
QJsonDocument textObjDoc(textObj);
113-
info.date = resultObj["date"].toString();
114-
info.text = textObjDoc.toJson(QJsonDocument::Indented);
115-
info.url = resultObj["url"].toString();
116-
QJsonObject meta_url = resultObj["meta_url"].toObject();
117-
info.favicon = meta_url["favicon"].toString();
118-
info.title = resultObj["title"].toString();
119-
120-
cleaned = cleanedObj;
121-
} else if (r_type == "faq") {
122-
// For faq data - take a list of all the questions & answers
123-
QStringList selectedKeys = {"type", "question", "answer", "title", "url"};
124-
QJsonArray cleanedArray;
125-
for (const auto& q : resultsArray) {
126-
QJsonObject qObj = q.toObject();
127-
QJsonObject cleanedObj;
128-
for (const auto& key : selectedKeys) {
129-
if (qObj.contains(key)) {
130-
cleanedObj.insert(key, qObj[key]);
131-
}
132-
}
133-
cleanedArray.append(cleanedObj);
134-
}
135-
cleaned = cleanedArray;
136-
} else if (r_type == "infobox") {
137-
QJsonObject resultObj = resultsArray[idx].toObject();
138-
QStringList selectedKeys = {"type", "title", "url", "description", "long_desc"};
139-
QJsonObject cleanedObj;
140-
for (const auto& key : selectedKeys) {
141-
if (resultObj.contains(key)) {
142-
cleanedObj.insert(key, resultObj[key]);
143-
}
144-
}
145-
cleaned = cleanedObj;
146-
} else if (r_type == "videos") {
147-
QStringList selectedKeys = {"type", "url", "title", "description", "date"};
148-
QJsonArray cleanedArray;
149-
for (const auto& q : resultsArray) {
150-
QJsonObject qObj = q.toObject();
151-
QJsonObject cleanedObj;
152-
for (const auto& key : selectedKeys) {
153-
if (qObj.contains(key)) {
154-
cleanedObj.insert(key, qObj[key]);
155-
}
156-
}
157-
cleanedArray.append(cleanedObj);
158-
}
159-
cleaned = cleanedArray;
160-
} else if (r_type == "locations") {
161-
QStringList selectedKeys = {"type", "title", "url", "description", "coordinates", "postal_address", "contact", "rating", "distance", "zoom_level"};
162-
QJsonArray cleanedArray;
163-
for (const auto& q : resultsArray) {
164-
QJsonObject qObj = q.toObject();
165-
QJsonObject cleanedObj;
166-
for (const auto& key : selectedKeys) {
167-
if (qObj.contains(key)) {
168-
cleanedObj.insert(key, qObj[key]);
169-
}
170-
}
171-
cleanedArray.append(cleanedObj);
172-
}
173-
cleaned = cleanedArray;
174-
} else if (r_type == "news") {
175-
QStringList selectedKeys = {"type", "title", "url", "description"};
176-
QJsonArray cleanedArray;
177-
for (const auto& q : resultsArray) {
178-
QJsonObject qObj = q.toObject();
179-
QJsonObject cleanedObj;
180-
for (const auto& key : selectedKeys) {
181-
if (qObj.contains(key)) {
182-
cleanedObj.insert(key, qObj[key]);
183-
}
184-
}
185-
cleanedArray.append(cleanedObj);
186-
}
187-
cleaned = cleanedArray;
188-
} else {
189-
cleaned = QJsonValue();
190120
}
191-
192-
infos.append(info);
193-
cleanArray.append(cleaned);
121+
result.insert("excerpts", excerpts);
122+
cleanArray.append(QJsonValue(result));
194123
}
195124
}
196125

197126
cleanResponse.insert("query", query);
198-
cleanResponse.insert("top_k", cleanArray);
127+
cleanResponse.insert("results", cleanArray);
199128
QJsonDocument cleanedDoc(cleanResponse);
200-
201129
// qDebug().noquote() << document.toJson(QJsonDocument::Indented);
202130
// qDebug().noquote() << cleanedDoc.toJson(QJsonDocument::Indented);
203-
204-
return qMakePair(cleanedDoc.toJson(QJsonDocument::Indented), infos);
131+
return cleanedDoc.toJson(QJsonDocument::Compact);
205132
}
206133

207134
void BraveAPIWorker::handleFinished()

gpt4all-chat/bravesearch.h

+6-9
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define BRAVESEARCH_H
33

44
#include "sourceexcerpt.h"
5+
#include "tool.h"
56

67
#include <QObject>
78
#include <QString>
@@ -17,7 +18,7 @@ class BraveAPIWorker : public QObject {
1718
, m_topK(1) {}
1819
virtual ~BraveAPIWorker() {}
1920

20-
QPair<QString, QList<SourceExcerpt>> response() const { return m_response; }
21+
QString response() const { return m_response; }
2122

2223
public Q_SLOTS:
2324
void request(const QString &apiKey, const QString &query, int topK);
@@ -31,21 +32,17 @@ private Q_SLOTS:
3132

3233
private:
3334
QNetworkAccessManager *m_networkManager;
34-
QPair<QString, QList<SourceExcerpt>> m_response;
35+
QString m_response;
3536
int m_topK;
3637
};
3738

38-
class BraveSearch : public QObject {
39+
class BraveSearch : public Tool {
3940
Q_OBJECT
4041
public:
41-
BraveSearch()
42-
: QObject(nullptr) {}
42+
BraveSearch() : Tool() {}
4343
virtual ~BraveSearch() {}
4444

45-
QPair<QString, QList<SourceExcerpt>> search(const QString &apiKey, const QString &query, int topK, unsigned long timeout = 2000);
46-
47-
Q_SIGNALS:
48-
void request(const QString &apiKey, const QString &query, int topK);
45+
QString run(const QJsonObject &parameters, qint64 timeout = 2000) override;
4946
};
5047

5148
#endif // BRAVESEARCH_H

gpt4all-chat/chatllm.cpp

+17-5
Original file line numberDiff line numberDiff line change
@@ -880,14 +880,26 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
880880

881881
const QString query = args["query"].toString();
882882

883-
// FIXME: This has to handle errors of the tool call
884883
emit toolCalled(tr("searching web..."));
885884
const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey();
886885
Q_ASSERT(apiKey != "");
887886
BraveSearch brave;
888-
const QPair<QString, QList<SourceExcerpt>> braveResponse = brave.search(apiKey, query, 2 /*topK*/,
889-
2000 /*msecs to timeout*/);
890-
emit sourceExcerptsChanged(braveResponse.second);
887+
888+
QJsonObject parameters;
889+
parameters.insert("apiKey", apiKey);
890+
parameters.insert("query", query);
891+
parameters.insert("count", 2);
892+
893+
// FIXME: This has to handle errors of the tool call
894+
const QString braveResponse = brave.run(parameters, 2000 /*msecs to timeout*/);
895+
896+
QString parseError;
897+
QList<SourceExcerpt> sourceExcerpts = SourceExcerpt::fromJson(braveResponse, parseError);
898+
if (!parseError.isEmpty()) {
899+
qWarning() << "ERROR: Could not parse source excerpts for brave response" << parseError;
900+
} else if (!sourceExcerpts.isEmpty()) {
901+
emit sourceExcerptsChanged(sourceExcerpts);
902+
}
891903

892904
// Erase the context of the tool call
893905
m_ctx.n_past = std::max(0, m_ctx.n_past);
@@ -898,7 +910,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
898910

899911
// This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive
900912
// tool calls
901-
return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, toolTemplate,
913+
return promptInternal(QList<QString>()/*collectionList*/, braveResponse, toolTemplate,
902914
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens,
903915
true /*isToolCallResponse*/);
904916

gpt4all-chat/qml/ChatView.qml

+8-1
Original file line numberDiff line numberDiff line change
@@ -1133,7 +1133,14 @@ Rectangle {
11331133
sourceSize.width: 24
11341134
sourceSize.height: 24
11351135
mipmap: true
1136-
source: consolidatedSources[0].url === "" ? "qrc:/gpt4all/icons/db.svg" : "qrc:/gpt4all/icons/globe.svg"
1136+
source: {
1137+
if (typeof consolidatedSources === 'undefined'
1138+
|| typeof consolidatedSources[0] === 'undefined'
1139+
|| consolidatedSources[0].url === "")
1140+
return "qrc:/gpt4all/icons/db.svg";
1141+
else
1142+
return "qrc:/gpt4all/icons/globe.svg";
1143+
}
11371144
}
11381145

11391146
ColorOverlay {

0 commit comments

Comments
 (0)