-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #550 from MichaelDecent/CloudWeaviate
Cloud weaviate VectorStore
- Loading branch information
Showing
2 changed files
with
251 additions
and
0 deletions.
There are no files selected for viewing
184 changes: 184 additions & 0 deletions
184
pkgs/community/swarmauri_community/vector_stores/CloudWeaviateVectorStore.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
from typing import List, Union, Literal | ||
from weaviate.classes.query import MetadataQuery | ||
import uuid as ud | ||
import weaviate | ||
import weaviate.classes as wvc | ||
from weaviate.classes.init import Auth | ||
from weaviate.util import generate_uuid5 | ||
# from weaviate import Client, AuthApiKey | ||
from swarmauri.vectors.concrete.Vector import Vector | ||
|
||
from swarmauri.documents.concrete.Document import Document # Replace with your actual import | ||
from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding # Replace with your actual import | ||
|
||
|
||
class CloudWeaviateVectorStore:
    """
    Vector store that keeps documents in a Weaviate Cloud collection.

    The store talks to Weaviate through the v4 Python client
    (``client.collections.get(...)``) and embeds documents that arrive
    without an embedding using a ``Doc2VecEmbedding`` instance.

    Document ids are mapped to Weaviate object UUIDs with
    ``uuid5(namespace_uuid, document.id)``.  The namespace defaults to a
    random UUID per instance, so ids written by one instance can only be
    resolved by another instance when both are given the same
    ``namespace_uuid``.

    Error handling: write operations (add/delete/update/close) print the
    error and re-raise; read operations print the error and return an
    empty result (``None``, ``[]`` or ``0``).
    """

    type: Literal["CloudWeaviateVectorStore"] = "CloudWeaviateVectorStore"

    def __init__(self, url: str, api_key: str, collection_name: str, vector_size: int, **kwargs):
        """
        Connect to a Weaviate Cloud cluster.

        Args:
            url: Cluster URL passed to ``connect_to_weaviate_cloud``.
            api_key: API key used for authentication.
            collection_name: Name of the collection all operations target.
            vector_size: Vector size handed to ``Doc2VecEmbedding``.
            **kwargs: Optional settings:
                headers: Extra headers for the Weaviate client (default ``{}``).
                namespace_uuid: ``uuid.UUID`` (or its string form) used as the
                    uuid5 namespace for document ids.  Defaults to a random
                    UUID, which keeps the original per-instance behaviour.
        """
        self.url = url
        self.api_key = api_key
        self.collection_name = collection_name
        self.vector_size = vector_size

        self._embedder = Doc2VecEmbedding(vector_size=vector_size)
        self.vectorizer = self._embedder

        # A caller-supplied namespace makes document ids resolvable across
        # instances; without one every instance gets its own random namespace.
        namespace = kwargs.get("namespace_uuid")
        if namespace is None:
            namespace = ud.uuid4()
        elif not isinstance(namespace, ud.UUID):
            namespace = ud.UUID(str(namespace))
        self.namespace_uuid = namespace

        # Initialize Weaviate client with v4 authentication
        self.client = weaviate.connect_to_weaviate_cloud(
            cluster_url=self.url,
            auth_credentials=Auth.api_key(self.api_key),
            headers=kwargs.get("headers", {}),
        )

    def _collection(self):
        """Return the v4 collection handle for ``self.collection_name``."""
        return self.client.collections.get(self.collection_name)

    def _object_uuid(self, id: str) -> ud.UUID:
        """Map a document id to its Weaviate object UUID in this store's namespace."""
        return ud.uuid5(self.namespace_uuid, id)

    def add_document(self, document: Document) -> None:
        """
        Add a single document to the vector store.

        The document's own embedding is used when present; otherwise one is
        computed with the store's vectorizer.

        Raises:
            Exception: Any error from embedding or insertion is printed and
                re-raised.
        """
        try:
            collection = self._collection()

            if not document.embedding:
                # NOTE(review): fit_transform is called with only this one
                # document - confirm that refitting per document is intended.
                embedding = self.vectorizer.fit_transform([document.content])[0]
            else:
                embedding = document.embedding

            properties = {
                "content": document.content,
                "metadata": document.metadata,
            }

            # Documents without an id fall back to a content-derived UUID.
            if document.id:
                object_uuid = str(self._object_uuid(document.id))
            else:
                object_uuid = generate_uuid5(properties)

            collection.data.insert(
                properties=properties,
                vector=embedding.value,
                uuid=object_uuid,
            )

            print(f"Document '{document.id}' added to Weaviate.")
        except Exception as e:
            print(f"Error adding document '{document.id}': {e}")
            raise

    def add_documents(self, documents: List[Document]) -> None:
        """
        Add multiple documents to the vector store, one insert per document.

        Raises:
            Exception: The first failing insert is printed and re-raised;
                documents inserted before the failure stay in the store.
        """
        try:
            for document in documents:
                self.add_document(document)

            print(f"{len(documents)} documents added to Weaviate.")
        except Exception as e:
            print(f"Error adding documents: {e}")
            raise

    def get_document(self, id: str) -> Union[Document, None]:
        """
        Retrieve a single document by its identifier.

        Returns:
            The stored document (carrying the requested ``id``), or ``None``
            when nothing is found or the lookup fails.
        """
        try:
            result = self._collection().query.fetch_object_by_id(self._object_uuid(id))
            if not result:
                return None

            # The Weaviate UUID is a one-way hash of the id, so the id is
            # restored from the argument rather than from the stored object.
            return Document(
                id=id,
                content=result.properties["content"],
                metadata=result.properties["metadata"],
            )
        except Exception as e:
            print(f"Error retrieving document '{id}': {e}")
            return None

    def get_all_documents(self) -> List[Document]:
        """
        Retrieve all documents from the vector store, including embeddings.

        Returns:
            Every document in the collection, or ``[]`` when the collection
            is empty or the request fails.
        """
        try:
            collection = self._collection()

            return [
                Document(
                    content=item.properties["content"],
                    metadata=item.properties["metadata"],
                    # item.vector maps vector names to values; take the first.
                    embedding=Vector(value=next(iter(item.vector.values()))),
                )
                for item in collection.iterator(include_vector=True)
            ]
        except Exception as e:
            print(f"Error retrieving all documents: {e}")
            return []

    def delete_document(self, id: str) -> None:
        """
        Delete a document from the vector store by its identifier.

        Raises:
            Exception: Any error from the delete call is printed and re-raised.
        """
        try:
            self._collection().data.delete_by_id(self._object_uuid(id))
            print(f"Document '{id}' has been deleted from Weaviate.")
        except Exception as e:
            print(f"Error deleting document '{id}': {e}")
            raise

    def update_document(self, document: Document) -> None:
        """
        Replace the stored copy of ``document`` by deleting and re-adding it.

        Raises:
            Exception: Propagated from ``delete_document`` or ``add_document``.
        """
        self.delete_document(document.id)
        self.add_document(document)

    def document_count(self) -> int:
        """
        Returns the number of documents in the store.

        Returns ``0`` when the count cannot be obtained.
        """
        try:
            aggregate = self._collection().aggregate.over_all(total_count=True)
            return aggregate.total_count or 0
        except Exception as e:
            print(f"Error counting documents: {e}")
            return 0

    def retrieve(self, query: str, top_k: int = 5) -> List[Document]:
        """
        Retrieve the top_k most relevant documents based on the given query.

        The query is embedded with the store's vectorizer and matched with a
        near-vector search.

        Returns:
            Up to ``top_k`` documents, or ``[]`` when the search fails.
        """
        try:
            collection = self._collection()
            query_vector = self.vectorizer.infer_vector(query)
            response = collection.query.near_vector(
                near_vector=query_vector.value,
                limit=top_k,
                return_metadata=MetadataQuery(distance=True),
            )

            return [
                Document(
                    content=res.properties["content"],
                    metadata=res.properties["metadata"],
                )
                for res in response.objects
            ]
        except Exception as e:
            print(f"Error retrieving documents for query '{query}': {e}")
            return []

    def close(self):
        """
        Close the connection to the Weaviate server.

        Raises:
            Exception: Any error from the client is printed and re-raised.
        """
        try:
            self.client.close()
        except Exception as e:
            print(f"Error closing connection: {e}")
            raise
67 changes: 67 additions & 0 deletions
67
pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import os | ||
import pytest | ||
from swarmauri.documents.concrete.Document import Document | ||
from swarmauri_community.vector_stores.CloudWeaviateVectorStore import ( | ||
CloudWeaviateVectorStore, | ||
) | ||
from dotenv import load_dotenv | ||
|
||
# Pull variables from a local .env file into the process environment before
# the credentials are read below.
load_dotenv()

# Connection settings for the live Weaviate Cloud cluster; either may be None
# when the corresponding environment variable is not set.
WEAVIATE_URL = os.environ.get("WEAVIATE_URL")
WEAVIATE_API_KEY = os.environ.get("WEAVIATE_API_KEY")
|
||
|
||
@pytest.fixture(scope="module")
def vector_store():
    """
    Provide one CloudWeaviateVectorStore shared by every test in this module.

    The tests are skipped when WEAVIATE_URL or WEAVIATE_API_KEY is missing.
    The store's connection is closed once the module's tests have finished.
    """
    if not all([WEAVIATE_URL, WEAVIATE_API_KEY]):
        pytest.skip("Skipping due to environment variable not set")
    vs = CloudWeaviateVectorStore(
        url=WEAVIATE_URL,
        api_key=WEAVIATE_API_KEY,
        collection_name="example",
        vector_size=100,
    )
    try:
        yield vs
    finally:
        # Teardown: release the cloud connection opened in __init__.
        vs.close()
|
||
|
||
@pytest.mark.unit
def test_ubc_type(vector_store):
    """The store's ``type`` attribute equals its class name."""
    expected_type = "CloudWeaviateVectorStore"
    assert vector_store.type == expected_type
|
||
@pytest.mark.unit
def test_ubc_resource(vector_store):
    """The store and its embedder report the expected ``resource`` values."""
    store_resource = vector_store.resource
    assert store_resource == "VectorStore"

    embedder_resource = vector_store.embedder.resource
    assert embedder_resource == "Embedding"
|
||
|
||
@pytest.mark.unit
def test_serialization(vector_store):
    """
    Test to verify serialization and deserialization of CloudWeaviateVectorStore.

    The id must survive a JSON dump followed by model validation.
    """
    original_id = vector_store.id
    serialized = vector_store.model_dump_json()
    restored = CloudWeaviateVectorStore.model_validate_json(serialized)
    assert original_id == restored.id
|
||
|
||
@pytest.mark.unit
def test_top_k(vector_store):
    """Adding two documents and retrieving with ``top_k=1`` yields one result."""
    documents = [
        Document(
            id="doc-001",
            content="This is the content of the first document.",
            metadata={"author": "Alice", "date": "2024-09-25"},
        ),
        Document(
            id="doc-002",
            content="The second document contains different information.",
            metadata={"author": "Bob", "date": "2024-09-26"},
        ),
    ]

    for document in documents:
        vector_store.add_document(document)

    results = vector_store.retrieve(query="information", top_k=1)
    assert len(results) == 1