Skip to content

Commit

Permalink
Merge pull request #550 from MichaelDecent/CloudWeaviate
Browse files Browse the repository at this point in the history
Cloud weaviate VectorStore
  • Loading branch information
cobycloud authored Sep 26, 2024
2 parents 006856c + e439708 commit 0c8bac5
Show file tree
Hide file tree
Showing 2 changed files with 251 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from typing import List, Union, Literal
from weaviate.classes.query import MetadataQuery
import uuid as ud
import weaviate
import weaviate.classes as wvc
from weaviate.classes.init import Auth
from weaviate.util import generate_uuid5
# from weaviate import Client, AuthApiKey
from swarmauri.vectors.concrete.Vector import Vector

from swarmauri.documents.concrete.Document import Document # Replace with your actual import
from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding # Replace with your actual import


class CloudWeaviateVectorStore():
"""
CloudWeaviateVectorStore is a concrete implementation that integrates functionality
for saving, loading, storing, and retrieving vector documents, leveraging Weaviate as the backend.
"""
type: Literal["CloudWeaviateVectorStore"] = "CloudWeaviateVectorStore"

def __init__(self, url: str, api_key: str, collection_name: str, vector_size: int, **kwargs):
self.url = url
self.api_key = api_key
self.collection_name = collection_name
self.vector_size = vector_size

self._embedder = Doc2VecEmbedding(vector_size=vector_size)
self.vectorizer = self._embedder

self.namespace_uuid = ud.uuid4()

# Initialize Weaviate client with v4 authentication
self.client = weaviate.connect_to_weaviate_cloud (
cluster_url=self.url,
auth_credentials=Auth.api_key(self.api_key),
headers=kwargs.get("headers", {})
)


def add_document(self, document: Document) -> None:
"""
Add a single document to the vector store.
"""
try:
jeopardy = self.client.collections.get(self.collection_name)

if not document.embedding:
embedding = self.vectorizer.fit_transform([document.content])[0]
else:
embedding = document.embedding

data_object = {
"content": document.content,
"metadata": document.metadata,

}

uuid = jeopardy.data.insert(
properties=data_object,
vector=embedding.value,
uuid=str(ud.uuid5(self.namespace_uuid, document.id)) if document.id else generate_uuid5(data_object)
)

print(f"Document '{document.id}' added to Weaviate.")
except Exception as e:
print(f"Error adding document '{document.id}': {e}")
raise

def add_documents(self, documents: List[Document]) -> None:
"""
Add multiple documents to the vector store in a batch.
"""
try:
for document in documents:
self.add_document(document)

print(f"{len(documents)} documents added to Weaviate.")
except Exception as e:
print(f"Error adding documents: {e}")
raise

def get_document(self, id: str) -> Union[Document, None]:
"""
Retrieve a single document by its identifier.
"""
try:
jeopardy = self.client.collections.get(self.collection_name)

result = jeopardy.query.fetch_object_by_id(ud.uuid5(self.namespace_uuid, id))

if result:

return Document(

content=result.properties["content"],
metadata=result.properties["metadata"],

)
return None
except Exception as e:
print(f"Error retrieving document '{id}': {e}")
return None

def get_all_documents(self) -> List[Document]:
"""
Retrieve all documents from the vector store.
"""
try:
collection = self.client.collections.get(self.collection_name)

documents = [Document(

content=item.properties["content"],
metadata=item.properties["metadata"],
embedding=Vector(value=list(item.vector.values())[0])
) for item in collection.iterator(include_vector=True)]
print(documents[0])
return documents
except Exception as e:
print(f"Error retrieving all documents: {e}")
return []

def delete_document(self, id: str) -> None:
"""
Delete a document from the vector store by its identifier.
"""
try:
collection = self.client.collections.get(self.collection_name)
collection.data.delete_by_id(ud.uuid5(self.namespace_uuid, id))
print(f"Document '{id}' has been deleted from Weaviate.")
except Exception as e:
print(f"Error deleting document '{id}': {e}")
raise

def update_document(self, document: Document) -> None:
self.delete_document(id)
self.add_document(document)

def document_count(self) -> int:
"""
Returns the number of documents in the store.
"""
try:
result = self.client.query.aggregate(self.collection_name).with_meta_count().do()
count = result["data"]["Aggregate"][self.collection_name][0]["meta"]["count"]
return count
except Exception as e:
print(f"Error counting documents: {e}")
return 0

def retrieve(self, query: str, top_k: int = 5) -> List[Document]:
"""
Retrieve the top_k most relevant documents based on the given query.
"""
try:
jeopardy = self.client.collections.get(self.collection_name)
query_vector = self.vectorizer.infer_vector(query)
response = jeopardy.query.near_vector(
near_vector=query_vector.value, # your query vector goes here
limit=top_k,
return_metadata=MetadataQuery(distance=True)
)

documents = [
Document(
content=res.properties["content"],
metadata=res.properties["metadata"]
) for res in response.objects
]
return documents
except Exception as e:
print(f"Error retrieving documents for query '{query}': {e}")
return []

def close(self):
"""
Close the connection to the Weaviate server.
"""
try:
self.client.close()
except Exception as e:
print(f"Error closing connection: {e}")
raise
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
import pytest
from swarmauri.documents.concrete.Document import Document
from swarmauri_community.vector_stores.CloudWeaviateVectorStore import (
CloudWeaviateVectorStore,
)
from dotenv import load_dotenv

load_dotenv()

WEAVIATE_URL = os.getenv("WEAVIATE_URL")
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")


@pytest.fixture(scope="module")
def vector_store():
if not all([WEAVIATE_URL, WEAVIATE_API_KEY]):
pytest.skip("Skipping due to environment variable not set")
vs = CloudWeaviateVectorStore(
url=WEAVIATE_URL,
api_key=WEAVIATE_API_KEY,
collection_name="example",
vector_size=100,
)
return vs


@pytest.mark.unit
def test_ubc_type(vector_store):
assert vector_store.type == "CloudWeaviateVectorStore"

@pytest.mark.unit
def test_ubc_resource(vector_store):
assert vector_store.resource == "VectorStore"
assert vector_store.embedder.resource == "Embedding"


@pytest.mark.unit
def test_serialization(vector_store):
"""
Test to verify serialization and deserialization of Neo4jVectorStore.
"""
assert (
vector_store.id
== CloudWeaviateVectorStore.model_validate_json(
vector_store.model_dump_json()
).id
)


@pytest.mark.unit
def test_top_k(vector_store):

document1 = Document(
id="doc-001",
content="This is the content of the first document.",
metadata={"author": "Alice", "date": "2024-09-25"},
)
document2 = Document(
id="doc-002",
content="The second document contains different information.",
metadata={"author": "Bob", "date": "2024-09-26"},
)

vector_store.add_document(document1)
vector_store.add_document(document2)
assert len(vector_store.retrieve(query="information", top_k=1)) == 1

0 comments on commit 0c8bac5

Please sign in to comment.