Skip to content

Commit

Permalink
feat: ensure retriever returns an image and send it to the LLM base64…
Browse files Browse the repository at this point in the history
… encoded
  • Loading branch information
marcusschiesser committed Dec 18, 2023
1 parent f4b0961 commit 127a8ca
Show file tree
Hide file tree
Showing 13 changed files with 180 additions and 346 deletions.
6 changes: 6 additions & 0 deletions .changeset/large-plums-drum.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"llamaindex": patch
---

Added support for multi-modal RAG (retriever and query engine), including an example
Fixed persisting and loading image vector stores
Binary file removed examples/multimodal/data/1.jpg
Binary file not shown.
Binary file removed examples/multimodal/data/2.jpg
Binary file not shown.
Binary file removed examples/multimodal/data/3.jpg
Binary file not shown.
323 changes: 0 additions & 323 deletions examples/multimodal/data/San Francisco.txt

This file was deleted.

23 changes: 19 additions & 4 deletions examples/multimodal/rag.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import {
CallbackManager,
ImageDocument,
ImageType,
MultiModalResponseSynthesizer,
NodeWithScore,
OpenAI,
ServiceContext,
VectorStoreIndex,
Expand All @@ -21,23 +25,34 @@ export async function createIndex(serviceContext: ServiceContext) {
}

async function main() {
let images: ImageType[] = [];
const callbackManager = new CallbackManager({
onRetrieve: ({ query, nodes }) => {
images = nodes
.filter(({ node }: NodeWithScore) => node instanceof ImageDocument)
.map(({ node }: NodeWithScore) => (node as ImageDocument).image);
},
});
const llm = new OpenAI({ model: "gpt-4-vision-preview", maxTokens: 512 });
const serviceContext = serviceContextFromDefaults({
llm,
chunkSize: 512,
chunkOverlap: 20,
callbackManager,
});
const index = await createIndex(serviceContext);

const queryEngine = index.asQueryEngine({
responseSynthesizer: new MultiModalResponseSynthesizer({ serviceContext }),
// TODO: set imageSimilarityTopK: 1,
retriever: index.asRetriever({ similarityTopK: 2 }),
retriever: index.asRetriever({ similarityTopK: 3, imageSimilarityTopK: 1 }),
});
const result = await queryEngine.query(
"what are Vincent van Gogh's famous paintings",
"Tell me more about Vincent van Gogh's famous paintings",
);
console.log(result.response, "\n");
images.forEach((image) =>
console.log(`Image retrieved and used in inference: ${image.toString()}`),
);
console.log(result.response);
}

main().catch(console.error);
3 changes: 1 addition & 2 deletions examples/multimodal/retrieve.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import {
TextNode,
VectorStoreIndex,
} from "llamaindex";
import * as path from "path";

export async function createIndex() {
// set up vector store index with two vector stores, one for text, the other for images
Expand Down Expand Up @@ -37,7 +36,7 @@ async function main() {
continue;
}
if (node instanceof ImageNode) {
console.log(`Image: ${path.join(__dirname, node.id_)}`);
console.log(`Image: ${node.getUrl()}`);
} else if (node instanceof TextNode) {
console.log("Text:", (node as TextNode).text.substring(0, 128));
}
Expand Down
1 change: 1 addition & 0 deletions packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"@xenova/transformers": "^2.10.0",
"assemblyai": "^4.0.0",
"crypto-js": "^4.2.0",
"file-type": "^18.7.0",
"js-tiktoken": "^1.0.8",
"lodash": "^4.17.21",
"mammoth": "^1.6.0",
Expand Down
7 changes: 7 additions & 0 deletions packages/core/src/Node.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import CryptoJS from "crypto-js";
import path from "path";
import { pathToFileURL } from "url";
import { v4 as uuidv4 } from "uuid";

export enum NodeRelationship {
Expand Down Expand Up @@ -304,6 +305,12 @@ export class ImageNode<T extends Metadata = Metadata> extends TextNode<T> {
getType(): ObjectType {
return ObjectType.IMAGE;
}

getUrl(): URL {
  // id_ stores the path relative to the working directory; resolve it to
  // an absolute path and convert that to a file:// URL. pathToFileURL
  // (unlike manual `file://${absPath}` string concatenation) correctly
  // handles Windows drive letters and backslashes, and percent-encodes
  // characters that are not valid in a URL.
  return pathToFileURL(path.resolve(this.id_));
}
}

export class ImageDocument<T extends Metadata = Metadata> extends ImageNode<T> {
Expand Down
62 changes: 61 additions & 1 deletion packages/core/src/embeddings/utils.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import _ from "lodash";
import { ImageType } from "../Node";
import { DEFAULT_SIMILARITY_TOP_K } from "../constants";
import { VectorStoreQueryMode } from "../storage";
import { DEFAULT_FS, VectorStoreQueryMode } from "../storage";
import { SimilarityType } from "./types";

/**
Expand Down Expand Up @@ -185,6 +185,16 @@ export function getTopKMMREmbeddings(
return [resultSimilarities, resultIds];
}

/**
 * Encode a Blob's bytes as a base64 `data:` URL.
 *
 * The MIME type is sniffed from the raw bytes using the `file-type`
 * package, which is imported lazily.
 *
 * @throws Error when the bytes do not match any known file type.
 */
async function blobToDataUrl(input: Blob) {
  const { fileTypeFromBuffer } = await import("file-type");
  const bytes = Buffer.from(await input.arrayBuffer());
  const detected = await fileTypeFromBuffer(bytes);
  if (!detected) {
    throw new Error("Unsupported image type");
  }
  return `data:${detected.mime};base64,${bytes.toString("base64")}`;
}

export async function readImage(input: ImageType) {
const { RawImage } = await import("@xenova/transformers");
if (input instanceof Blob) {
Expand All @@ -195,3 +205,53 @@ export async function readImage(input: ImageType) {
throw new Error(`Unsupported input type: ${typeof input}`);
}
}

/**
 * Serialize any supported image representation to a string.
 *
 * Blobs become base64 `data:` URLs; plain strings pass through
 * unchanged; URLs are stringified.
 *
 * @throws Error for any other input type.
 */
export async function imageToString(input: ImageType): Promise<string> {
  if (input instanceof Blob) {
    // Raw bytes: embed them as a base64 data URL.
    return blobToDataUrl(input);
  }
  if (typeof input === "string") {
    return input;
  }
  if (input instanceof URL) {
    return input.toString();
  }
  throw new Error(`Unsupported input type: ${typeof input}`);
}

/**
 * Deserialize a string produced by `imageToString` back into an image.
 *
 * `data:` URLs are decoded into Blobs, http(s) URLs become URL objects,
 * and anything else is returned as-is (treated as a local path).
 *
 * @throws Error for non-string input passed in from untyped callers.
 */
export function stringToImage(input: string): ImageType {
  if (input.startsWith("data:")) {
    // Decode the base64 payload after the comma back into raw bytes.
    const payload = input.split(",")[1];
    const bytes = Buffer.from(payload, "base64");
    return new Blob([bytes]);
  }
  if (input.startsWith("http://") || input.startsWith("https://")) {
    return new URL(input);
  }
  if (typeof input === "string") {
    return input;
  }
  throw new Error(`Unsupported input type: ${typeof input}`);
}

/**
 * Convert any supported image representation into a base64 `data:` URL.
 *
 * Plain strings and `file:` URLs are treated as local filesystem paths
 * and read into memory first; Blobs are encoded directly.
 *
 * @throws Error for URLs with a non-`file:` protocol or any other
 *   unsupported input type.
 */
export async function imageToDataUrl(input: ImageType): Promise<string> {
  if (
    (input instanceof URL && input.protocol === "file:") ||
    typeof input === "string"
  ) {
    // Local path (string or file URL): load the bytes into a Blob.
    const fs = DEFAULT_FS;
    const fileContents = await fs.readFile(
      input instanceof URL ? input.pathname : input,
    );
    input = new Blob([fileContents]);
  } else if (!(input instanceof Blob)) {
    if (input instanceof URL) {
      throw new Error(`Unsupported URL with protocol: ${input.protocol}`);
    }
    throw new Error(`Unsupported input type: ${typeof input}`);
  }
  return blobToDataUrl(input);
}
31 changes: 23 additions & 8 deletions packages/core/src/indices/vectorStore/VectorIndexRetriever.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { Event } from "../../callbacks/CallbackManager";
import { DEFAULT_SIMILARITY_TOP_K } from "../../constants";
import { BaseEmbedding } from "../../embeddings";
import { globalsHelper } from "../../GlobalsHelper";
import { Metadata, NodeWithScore } from "../../Node";
import { ImageNode, Metadata, NodeWithScore } from "../../Node";
import { BaseRetriever } from "../../Retriever";
import { ServiceContext } from "../../ServiceContext";
import { Event } from "../../callbacks/CallbackManager";
import { DEFAULT_SIMILARITY_TOP_K } from "../../constants";
import { BaseEmbedding } from "../../embeddings";
import {
VectorStoreQuery,
VectorStoreQueryMode,
Expand All @@ -18,20 +18,23 @@ import { VectorStoreIndex } from "./VectorStoreIndex";

export class VectorIndexRetriever implements BaseRetriever {
index: VectorStoreIndex;
similarityTopK;
similarityTopK: number;
imageSimilarityTopK: number;
private serviceContext: ServiceContext;

constructor({
index,
similarityTopK,
imageSimilarityTopK,
}: {
index: VectorStoreIndex;
similarityTopK?: number;
imageSimilarityTopK?: number;
}) {
this.index = index;
this.serviceContext = this.index.serviceContext;

this.similarityTopK = similarityTopK ?? DEFAULT_SIMILARITY_TOP_K;
this.imageSimilarityTopK = imageSimilarityTopK ?? DEFAULT_SIMILARITY_TOP_K;
}

async retrieve(
Expand All @@ -51,7 +54,11 @@ export class VectorIndexRetriever implements BaseRetriever {
query: string,
preFilters?: unknown,
): Promise<NodeWithScore[]> {
const q = await this.buildVectorStoreQuery(this.index.embedModel, query);
const q = await this.buildVectorStoreQuery(
this.index.embedModel,
query,
this.similarityTopK,
);
const result = await this.index.vectorStore.query(q, preFilters);
return this.buildNodeListFromQueryResult(result);
}
Expand All @@ -64,6 +71,7 @@ export class VectorIndexRetriever implements BaseRetriever {
const q = await this.buildVectorStoreQuery(
this.index.imageEmbedModel,
query,
this.imageSimilarityTopK,
);
const result = await this.index.imageVectorStore.query(q, preFilters);
return this.buildNodeListFromQueryResult(result);
Expand All @@ -89,13 +97,14 @@ export class VectorIndexRetriever implements BaseRetriever {
protected async buildVectorStoreQuery(
embedModel: BaseEmbedding,
query: string,
similarityTopK: number,
): Promise<VectorStoreQuery> {
const queryEmbedding = await embedModel.getQueryEmbedding(query);

return {
queryEmbedding: queryEmbedding,
mode: VectorStoreQueryMode.DEFAULT,
similarityTopK: this.similarityTopK,
similarityTopK: similarityTopK,
};
}

Expand All @@ -108,6 +117,12 @@ export class VectorIndexRetriever implements BaseRetriever {
}

const node = this.index.indexStruct.nodesDict[result.ids[i]];
// XXX: Hack, if it's an image node, we reconstruct the image from the URL
// Alternative: Store image in doc store and retrieve it here
if (node instanceof ImageNode) {
node.image = node.getUrl();
}

nodesWithScores.push({
node: node,
score: result.similarities[i],
Expand Down
26 changes: 18 additions & 8 deletions packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import { MessageContentDetail } from "../ChatEngine";
import { MetadataMode, NodeWithScore, splitNodesByType } from "../Node";
import {
ImageNode,
MetadataMode,
NodeWithScore,
splitNodesByType,
} from "../Node";
import { Response } from "../Response";
import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext";
import { Event } from "../callbacks/CallbackManager";
import { imageToDataUrl } from "../embeddings";
import { TextQaPrompt, defaultTextQaPrompt } from "./../Prompt";
import { BaseSynthesizer } from "./types";

Expand Down Expand Up @@ -34,15 +40,19 @@ export class MultiModalResponseSynthesizer implements BaseSynthesizer {
// TODO: use builders to generate context
const context = textChunks.join("\n\n");
const textPrompt = this.textQATemplate({ context, query });
// TODO: get images from imageNodes
const images = await Promise.all(
imageNodes.map(async (node: ImageNode) => {
return {
type: "image_url",
image_url: {
url: await imageToDataUrl(node.image),
},
} as MessageContentDetail;
}),
);
const prompt: MessageContentDetail[] = [
{ type: "text", text: textPrompt },
{
type: "image_url",
image_url: {
url: "https://upload.wikimedia.org/wikipedia/commons/b/b0/Vincent_van_Gogh_%281853-1890%29_Caf%C3%A9terras_bij_nacht_%28place_du_Forum%29_Kr%C3%B6ller-M%C3%BCller_Museum_Otterlo_23-8-2016_13-35-40.JPG",
},
},
...images,
];
let response = await this.serviceContext.llm.complete(prompt, parentEvent);
return new Response(response.message.content, nodes);
Expand Down
Loading

0 comments on commit 127a8ca

Please sign in to comment.