From e3f4b2849078a29835d5ff07c09980d2e9ea893e Mon Sep 17 00:00:00 2001 From: ishaan1234 <31759888+ishaan1234@users.noreply.github.com> Date: Thu, 26 Sep 2024 17:03:20 +0530 Subject: [PATCH 1/4] Add files via upload --- .../vector_stores/CloudWeaviateVectorStore.py | 184 ++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100644 pkgs/community/swarmauri_community/vector_stores/CloudWeaviateVectorStore.py diff --git a/pkgs/community/swarmauri_community/vector_stores/CloudWeaviateVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/CloudWeaviateVectorStore.py new file mode 100644 index 000000000..b976d5ccd --- /dev/null +++ b/pkgs/community/swarmauri_community/vector_stores/CloudWeaviateVectorStore.py @@ -0,0 +1,184 @@ +from typing import List, Union, Literal +from weaviate.classes.query import MetadataQuery +import uuid as ud +import weaviate +import weaviate.classes as wvc +from weaviate.classes.init import Auth +from weaviate.util import generate_uuid5 +# from weaviate import Client, AuthApiKey +from swarmauri.vectors.concrete.Vector import Vector + +from swarmauri.documents.concrete.Document import Document # Replace with your actual import +from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding # Replace with your actual import + + +class CloudWeaviateVectorStore(): + """ + CloudWeaviateVectorStore is a concrete implementation that integrates functionality + for saving, loading, storing, and retrieving vector documents, leveraging Weaviate as the backend. 
+ """ + type: Literal["CloudWeaviateVectorStore"] = "CloudWeaviateVectorStore" + + def __init__(self, url: str, api_key: str, collection_name: str, vector_size: int, **kwargs): + self.url = url + self.api_key = api_key + self.collection_name = collection_name + self.vector_size = vector_size + + self._embedder = Doc2VecEmbedding(vector_size=vector_size) + self.vectorizer = self._embedder + + self.namespace_uuid = ud.uuid4() + + # Initialize Weaviate client with v4 authentication + self.client = weaviate.connect_to_weaviate_cloud ( + cluster_url=self.url, + auth_credentials=Auth.api_key(self.api_key), + headers=kwargs.get("headers", {}) + ) + + + def add_document(self, document: Document) -> None: + """ + Add a single document to the vector store. + """ + try: + jeopardy = self.client.collections.get(self.collection_name) + + if not document.embedding: + embedding = self.vectorizer.fit_transform([document.content])[0] + else: + embedding = document.embedding + + data_object = { + "content": document.content, + "metadata": document.metadata, + + } + + uuid = jeopardy.data.insert( + properties=data_object, + vector=embedding.value, + uuid=str(ud.uuid5(self.namespace_uuid, document.id)) if document.id else generate_uuid5(data_object) + ) + + print(f"Document '{document.id}' added to Weaviate.") + except Exception as e: + print(f"Error adding document '{document.id}': {e}") + raise + + def add_documents(self, documents: List[Document]) -> None: + """ + Add multiple documents to the vector store in a batch. + """ + try: + for document in documents: + self.add_document(document) + + print(f"{len(documents)} documents added to Weaviate.") + except Exception as e: + print(f"Error adding documents: {e}") + raise + + def get_document(self, id: str) -> Union[Document, None]: + """ + Retrieve a single document by its identifier. 
+        """
+        try:
+            jeopardy = self.client.collections.get(self.collection_name)
+
+            result = jeopardy.query.fetch_object_by_id(ud.uuid5(self.namespace_uuid, id))
+
+            if result:
+
+                return Document(
+
+                    content=result.properties["content"],
+                    metadata=result.properties["metadata"],
+
+                )
+            return None
+        except Exception as e:
+            print(f"Error retrieving document '{id}': {e}")
+            return None
+
+    def get_all_documents(self) -> List[Document]:
+        """
+        Retrieve all documents from the vector store.
+        """
+        try:
+            collection = self.client.collections.get(self.collection_name)
+            documents = [Document(
+                content=item.properties["content"],
+                metadata=item.properties["metadata"],
+                embedding=Vector(value=list(item.vector.values())[0])
+            ) for item in collection.iterator(include_vector=True)]
+            return documents
+        except Exception as e:
+            print(f"Error retrieving all documents: {e}")
+            return []
+
+    def delete_document(self, id: str) -> None:
+        """
+        Delete a document from the vector store by its identifier.
+        """
+        try:
+            collection = self.client.collections.get(self.collection_name)
+            collection.data.delete_by_id(ud.uuid5(self.namespace_uuid, id))
+            print(f"Document '{id}' has been deleted from Weaviate.")
+        except Exception as e:
+            print(f"Error deleting document '{id}': {e}")
+            raise
+
+    def update_document(self, document: Document) -> None:
+        """
+        Update a document by deleting the stored copy and adding the new version.
+        """
+        self.delete_document(document.id)
+        self.add_document(document)
+
+    def document_count(self) -> int:
+        """
+        Returns the number of documents in the store.
+        """
+        try:
+            # v4 client: count via the collection's aggregate API (client.query is the v3 interface).
+            collection = self.client.collections.get(self.collection_name)
+            return collection.aggregate.over_all(total_count=True).total_count
+        except Exception as e:
+            print(f"Error counting documents: {e}")
+            return 0
+
+    def retrieve(self, query: str, top_k: int = 5) -> List[Document]:
+        """
+        Retrieve the top_k most relevant documents based on the given query.
+ """ + try: + jeopardy = self.client.collections.get(self.collection_name) + query_vector = self.vectorizer.infer_vector(query) + response = jeopardy.query.near_vector( + near_vector=query_vector.value, # your query vector goes here + limit=top_k, + return_metadata=MetadataQuery(distance=True) + ) + + documents = [ + Document( + content=res.properties["content"], + metadata=res.properties["metadata"] + ) for res in response.objects + ] + return documents + except Exception as e: + print(f"Error retrieving documents for query '{query}': {e}") + return [] + + def close(self): + """ + Close the connection to the Weaviate server. + """ + try: + self.client.close() + except Exception as e: + print(f"Error closing connection: {e}") + raise From 50d7a4ed75dc22d43e9efbff6fa5c18269876ca3 Mon Sep 17 00:00:00 2001 From: ishaan1234 <31759888+ishaan1234@users.noreply.github.com> Date: Thu, 26 Sep 2024 17:04:33 +0530 Subject: [PATCH 2/4] Add files via upload --- .../CloudWeaviateVectorStore_test.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py diff --git a/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py new file mode 100644 index 000000000..c2029481d --- /dev/null +++ b/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py @@ -0,0 +1,51 @@ +import os +import pytest +from swarmauri.documents.concrete.Document import Document +from swarmauri_community.vector_stores.CloudWeaviateVectorStore import CloudWeaviateVectorStore + +WEAVIATE_URL = "https://p6grmuovrkqie6kafxts2a.c0.asia-southeast1.gcp.weaviate.cloud" +WEAVIATE_API_KEY ="kAF7ar7sZqgFyZEhS4hL9eVAJ3Br5PwJP6An" + + +@pytest.mark.skipif( + not WEAVIATE_URL or not WEAVIATE_API_KEY, + reason="Skipping due to environment variables not set", +) +@pytest.mark.unit +def test_weaviate_type(): + vs = 
CloudWeaviateVectorStore( + url=WEAVIATE_URL, + api_key=WEAVIATE_API_KEY, + collection_name="example", + vector_size=100, + ) + assert vs.type == "CloudWeaviateVectorStore" + + + +@pytest.mark.skipif( + not WEAVIATE_URL or not WEAVIATE_API_KEY, + reason="Skipping due to environment variables not set", +) +@pytest.mark.unit +def test_top_k(): + vs = CloudWeaviateVectorStore( + url=WEAVIATE_URL, + api_key=WEAVIATE_API_KEY, + collection_name="example", + vector_size=100, + ) + document1 = Document( + id="doc-001", + content="This is the content of the first document.", + metadata={"author": "Alice", "date": "2024-09-25"}, + ) + document2 = Document( + id="doc-002", + content="The second document contains different information.", + metadata={"author": "Bob", "date": "2024-09-26"}, + ) + + vs.add_document(document1) + vs.add_document(document2) + assert len(vs.retrieve(query="information", top_k=1)) == 1 From 62e61c78c39ac67ba462d31322c50062b99f072d Mon Sep 17 00:00:00 2001 From: michaeldecent2 <111002205+MichaelDecent@users.noreply.github.com> Date: Thu, 26 Sep 2024 14:51:39 +0100 Subject: [PATCH 3/4] added all basic test files to CloudWeaviateVectorStore_test.py --- .../CloudWeaviateVectorStore_test.py | 65 ++++++++++++------- 1 file changed, 42 insertions(+), 23 deletions(-) diff --git a/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py index c2029481d..524572be7 100644 --- a/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py @@ -1,40 +1,59 @@ import os import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.CloudWeaviateVectorStore import CloudWeaviateVectorStore - -WEAVIATE_URL = "https://p6grmuovrkqie6kafxts2a.c0.asia-southeast1.gcp.weaviate.cloud" -WEAVIATE_API_KEY ="kAF7ar7sZqgFyZEhS4hL9eVAJ3Br5PwJP6An" +from 
swarmauri_community.vector_stores.CloudWeaviateVectorStore import ( + CloudWeaviateVectorStore, +) +from dotenv import load_dotenv +load_dotenv() -@pytest.mark.skipif( - not WEAVIATE_URL or not WEAVIATE_API_KEY, - reason="Skipping due to environment variables not set", +WEAVIATE_URL = os.getenv( + "WEAVIATE_URL", + "https://p6grmuovrkqie6kafxts2a.c0.asia-southeast1.gcp.weaviate.cloud", ) -@pytest.mark.unit -def test_weaviate_type(): +WEAVIATE_API_KEY = os.getenv("WEAVIATE_URL", "kAF7ar7sZqgFyZEhS4hL9eVAJ3Br5PwJP6An") + + +@pytest.fixture(scope="module") +def vector_store(): + if not all([WEAVIATE_URL, WEAVIATE_API_KEY]): + pytest.skip("Skipping due to environment variable not set") vs = CloudWeaviateVectorStore( url=WEAVIATE_URL, api_key=WEAVIATE_API_KEY, collection_name="example", vector_size=100, ) - assert vs.type == "CloudWeaviateVectorStore" + return vs +@pytest.mark.unit +def test_ubc_type(vector_store): + assert vector_store.type == "CloudWeaviateVectorStore" -@pytest.mark.skipif( - not WEAVIATE_URL or not WEAVIATE_API_KEY, - reason="Skipping due to environment variables not set", -) @pytest.mark.unit -def test_top_k(): - vs = CloudWeaviateVectorStore( - url=WEAVIATE_URL, - api_key=WEAVIATE_API_KEY, - collection_name="example", - vector_size=100, +def test_ubc_resource(vector_store): + assert vector_store.resource == "VectorStore" + assert vector_store.embedder.resource == "Embedding" + + +@pytest.mark.unit +def test_serialization(vector_store): + """ + Test to verify serialization and deserialization of Neo4jVectorStore. 
+ """ + assert ( + vector_store.id + == CloudWeaviateVectorStore.model_validate_json( + vector_store.model_dump_json() + ).id ) + + +@pytest.mark.unit +def test_top_k(vector_store): + document1 = Document( id="doc-001", content="This is the content of the first document.", @@ -46,6 +65,6 @@ def test_top_k(): metadata={"author": "Bob", "date": "2024-09-26"}, ) - vs.add_document(document1) - vs.add_document(document2) - assert len(vs.retrieve(query="information", top_k=1)) == 1 + vector_store.add_document(document1) + vector_store.add_document(document2) + assert len(vector_store.retrieve(query="information", top_k=1)) == 1 From e4397081f4825945148cc45aeeb8ed2c77ebc4a2 Mon Sep 17 00:00:00 2001 From: cobycloud <25079070+cobycloud@users.noreply.github.com> Date: Thu, 26 Sep 2024 09:07:45 -0500 Subject: [PATCH 4/4] Update CloudWeaviateVectorStore_test.py --- .../unit/vector_stores/CloudWeaviateVectorStore_test.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py index 524572be7..9669e33cd 100644 --- a/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py @@ -8,11 +8,8 @@ load_dotenv() -WEAVIATE_URL = os.getenv( - "WEAVIATE_URL", - "https://p6grmuovrkqie6kafxts2a.c0.asia-southeast1.gcp.weaviate.cloud", -) -WEAVIATE_API_KEY = os.getenv("WEAVIATE_URL", "kAF7ar7sZqgFyZEhS4hL9eVAJ3Br5PwJP6An") +WEAVIATE_URL = os.getenv("WEAVIATE_URL") +WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY") @pytest.fixture(scope="module")