diff --git a/pkgs/swarmauri/swarmauri/vector_stores/concrete/RedisVectorStore.py b/pkgs/swarmauri/swarmauri/vector_stores/concrete/RedisVectorStore.py
new file mode 100644
index 000000000..62f8cd098
--- /dev/null
+++ b/pkgs/swarmauri/swarmauri/vector_stores/concrete/RedisVectorStore.py
@@ -0,0 +1,210 @@
+import json
+from typing import List, Literal, Optional, Union
+
+import numpy as np
+import redis
+from redis.commands.search.field import TextField, VectorField
+from redis.commands.search.indexDefinition import IndexDefinition, IndexType
+
+from swarmauri.standard.vectors.concrete.Vector import Vector
+from swarmauri.standard.documents.concrete.Document import Document
+from swarmauri.standard.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding  # or your specific embedder
+from swarmauri.standard.vector_stores.base.VectorStoreBase import VectorStoreBase
+from swarmauri.standard.vector_stores.base.VectorStoreRetrieveMixin import VectorStoreRetrieveMixin
+from swarmauri.standard.vector_stores.base.VectorStoreSaveLoadMixin import VectorStoreSaveLoadMixin
+
+
+class RedisVectorStore(VectorStoreSaveLoadMixin, VectorStoreRetrieveMixin, VectorStoreBase):
+    type: Literal["RedisVectorStore"] = "RedisVectorStore"
+    index_name: str = "documents_index"
+    embedding_dimension: int = 8000
+
+    def __init__(
+        self,
+        redis_host: str = "localhost",
+        redis_port: int = 6379,
+        redis_password: Optional[str] = None,
+        embedding_dimension: int = 8000,  # Adjust based on your embedder
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self._embedder = Doc2VecEmbedding()  # Replace with your specific embedder if different
+        self.embedding_dimension = embedding_dimension
+
+        # Initialize the Redis client; keep binary responses so embedding bytes round-trip intact
+        self.redis_client = redis.Redis(
+            host=redis_host,
+            port=redis_port,
+            password=redis_password,
+            decode_responses=False,
+        )
+
+        # Drop any stale index so it is recreated with the current schema;
+        # ignore the error raised when the index does not exist yet.
+        try:
+            self.redis_client.ft(self.index_name).dropindex(delete_documents=False)
+        except redis.exceptions.ResponseError:
+            pass
+
+        vector_field = VectorField(
+            "embedding",
+            "FLAT",
+            {
+                "TYPE": "FLOAT32",
+                "DIM": self.embedding_dimension,
+                "DISTANCE_METRIC": "COSINE",
+            },
+        )
+        text_field = TextField("content")
+
+        try:
+            self.redis_client.ft(self.index_name).info()
+        except redis.exceptions.ResponseError:
+            schema = (text_field, vector_field)
+            definition = IndexDefinition(
+                prefix=["doc:"],
+                index_type=IndexType.HASH,
+            )
+            self.redis_client.ft(self.index_name).create_index(
+                fields=schema,
+                definition=definition,
+            )
+
+    def _doc_key(self, document_id: str) -> str:
+        return f"doc:{document_id}"
+
+    def add_document(self, document: Document) -> None:
+        pipeline = self.redis_client.pipeline()
+
+        # Embed the document content
+        embedding = self._embedder.fit_transform([document.content])[0]
+        if isinstance(embedding, Vector):
+            embedding = embedding.value
+
+        doc_key = self._doc_key(document.id)
+        pipeline.hset(doc_key, mapping={
+            "content": document.content,
+            "metadata": json.dumps(document.metadata),  # Store metadata as JSON
+            "embedding": np.array(embedding, dtype=np.float32).tobytes(),  # Store embedding values as raw bytes
+        })
+        pipeline.execute()
+
+    def add_documents(self, documents: List[Document]) -> None:
+        pipeline = self.redis_client.pipeline()
+        for doc in documents:
+            if not doc.content:
+                continue
+            # Embed the document content
+            embedding = self._embedder.fit_transform([doc.content])[0]
+            if isinstance(embedding, Vector):
+                embedding = embedding.value
+            metadata = doc.metadata
+
+            doc_key = self._doc_key(doc.id)
+            pipeline.hset(doc_key, mapping={
+                "content": doc.content,
+                "metadata": json.dumps(metadata),
+                "embedding": np.array(embedding, dtype=np.float32).tobytes(),
+            })
+        pipeline.execute()
+
+    def get_document(self, id: str) -> Union[Document, None]:
+        doc_key = self._doc_key(id)
+        data = self.redis_client.hgetall(doc_key)
+        if not data:
+            return None
+
+        metadata_raw = data.get(b"metadata", b"{}").decode("utf-8")
+        metadata = json.loads(metadata_raw)
+        content = data.get(b"content", b"").decode("utf-8")
+
+        embedding_bytes = data.get(b"embedding")
+        if embedding_bytes:
+            embedding = Vector(value=np.frombuffer(embedding_bytes, dtype=np.float32).tolist())
+        else:
+            embedding = None
+        return Document(
+            id=id,
+            content=content,
+            metadata=metadata,
+            embedding=embedding,
+        )
+
+    def get_all_documents(self) -> List[Document]:
+        documents = []
+        cursor = 0
+        while True:
+            cursor, keys = self.redis_client.scan(cursor=cursor, match="doc:*", count=1000)
+            for key in keys:
+                data = self.redis_client.hgetall(key)
+                if not data:
+                    continue
+                doc_id = key.decode("utf-8").split("doc:")[1]
+                metadata_raw = data.get(b"metadata", b"{}").decode("utf-8")
+                metadata = json.loads(metadata_raw)
+                content = data.get(b"content", b"").decode("utf-8")
+                embedding_bytes = data.get(b"embedding")
+                if embedding_bytes:
+                    embedding = Vector(value=np.frombuffer(embedding_bytes, dtype=np.float32).tolist())
+                else:
+                    embedding = None
+                documents.append(Document(
+                    id=doc_id,
+                    content=content,
+                    metadata=metadata,
+                    embedding=embedding,
+                ))
+            if cursor == 0:
+                break
+        return documents
+
+    def delete_document(self, id: str) -> None:
+        doc_key = self._doc_key(id)
+        self.redis_client.delete(doc_key)
+
+    def update_document(self, document: Document) -> None:
+        doc_key = self._doc_key(document.id)
+        if not self.redis_client.exists(doc_key):
+            raise ValueError(f"Document with id {document.id} does not exist.")
+        # Update the document by re-adding it
+        self.add_documents([document])
+
+    def cosine_similarity(self, vec1, vec2):
+        dot_product = np.dot(vec1, vec2)
+        norm_vec1 = np.linalg.norm(vec1)
+        norm_vec2 = np.linalg.norm(vec2)
+        if norm_vec1 == 0 or norm_vec2 == 0:
+            return 0
+        return dot_product / (norm_vec1 * norm_vec2)
+
+    def retrieve(self, query: str, top_k: int = 5) -> List[Document]:
+        # Brute-force retrieval: embed the query, then rank all stored documents by cosine similarity
+        query_vector = self._embedder.infer_vector(query)
+
+        all_documents = self.get_all_documents()
+        similarities = []
+        for doc in all_documents:
+            if doc.embedding is not None:
+                similarity = self.cosine_similarity(query_vector.value, doc.embedding.value)
+                similarities.append((doc, similarity))
+
+        similarities.sort(key=lambda x: x[1], reverse=True)
+        top_documents = [doc for doc, _ in similarities[:top_k]]
+        return top_documents
+
+    class Config:
+        extra = "allow"
diff --git a/pkgs/swarmauri/tests/unit/vector_stores/RedisVectorStore_test.py b/pkgs/swarmauri/tests/unit/vector_stores/RedisVectorStore_test.py
new file mode 100644
index 000000000..9f5942ab4
--- /dev/null
+++ b/pkgs/swarmauri/tests/unit/vector_stores/RedisVectorStore_test.py
@@ -0,0 +1,89 @@
+import pytest
+from swarmauri.standard.documents.concrete.Document import Document
+from swarmauri.vector_stores.concrete.RedisVectorStore import RedisVectorStore
+
+
+@pytest.fixture(scope="module")
+def vector_store():
+    return RedisVectorStore(
+        redis_host="redis-12648.c305.ap-south-1-1.ec2.redns.redis-cloud.com",
+        redis_port=12648,
+        redis_password='EaNg3YcgUW94Uj1P5wT3LScNtM97avu2',  # Replace with your password if needed
+        embedding_dimension=8000,  # Adjust based on your embedder
+    )
+
+
+# Create a sample document
+@pytest.fixture
+def sample_document():
+    return Document(
+        id="test_doc1",
+        content="This is a test document for unit testing.",
+        metadata={"category": "test"},
+    )
+
+
+@pytest.mark.unit
+def test_ubc_resource(vector_store):
+    assert vector_store.resource == 'VectorStore'
+
+
+@pytest.mark.unit
+def test_ubc_type(vector_store):
+    assert vector_store.type == 'RedisVectorStore'
+
+
+@pytest.mark.unit
+def test_top_k(vector_store):
+    documents = [
+        Document(content="test"),
+        Document(content="test1"),
+        Document(content="test2"),
+        Document(content="test3"),
+    ]
+    vector_store.add_documents(documents)
+    assert len(vector_store.retrieve(query="test", top_k=2)) == 2
+
+
+@pytest.mark.unit
+def test_add_and_get_document(vector_store, sample_document):
+    vector_store.add_document(sample_document)
+
+    retrieved_doc = vector_store.get_document("test_doc1")
+
+    assert retrieved_doc is not None
+    assert retrieved_doc.id == "test_doc1"
+    assert retrieved_doc.content == "This is a test document for unit testing."
+    assert retrieved_doc.metadata == {"category": "test"}
+
+
+@pytest.mark.unit
+def test_delete_document(vector_store, sample_document):
+    vector_store.add_document(sample_document)
+    vector_store.delete_document("test_doc1")
+
+    retrieved_doc = vector_store.get_document("test_doc1")
+    assert retrieved_doc is None
+
+
+@pytest.mark.unit
+def test_retrieve_similar_documents(vector_store):
+    doc1 = Document(id="doc1", content="Sample document content about testing.", metadata={"category": "sample"})
+    doc2 = Document(id="doc2", content="Another test document for retrieval.", metadata={"category": "sample"})
+
+    vector_store.add_document(doc1)
+    vector_store.add_document(doc2)
+
+    similar_docs = vector_store.retrieve("test document", top_k=2)
+
+    assert len(similar_docs) == 2
+    assert similar_docs[0].id in ("doc1", "doc2")
+    assert similar_docs[1].id in ("doc1", "doc2")
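A minimal usage sketch of the new store, for reviewers. It assumes a locally running Redis instance with the RediSearch module; the host, port, document ids, and example texts below are illustrative and not part of this change.

    from swarmauri.standard.documents.concrete.Document import Document
    from swarmauri.vector_stores.concrete.RedisVectorStore import RedisVectorStore

    # Hypothetical local Redis deployment; adjust host/port/password as needed.
    store = RedisVectorStore(redis_host="localhost", redis_port=6379, embedding_dimension=8000)

    # Documents are hashed under "doc:<id>" with content, JSON metadata, and float32 embedding bytes.
    store.add_documents([
        Document(id="a", content="Redis is an in-memory data store.", metadata={"topic": "redis"}),
        Document(id="b", content="Doc2Vec produces dense document embeddings.", metadata={"topic": "nlp"}),
    ])

    # retrieve() embeds the query and ranks all stored documents by cosine similarity.
    for doc in store.retrieve("in-memory store", top_k=1):
        print(doc.id, doc.content)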