Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/ishaan1234/swarmauri-sdk
Browse files Browse the repository at this point in the history
…into 0.4.5.dev1
  • Loading branch information
cobycloud committed Sep 24, 2024
2 parents 830ca1d + c6ea66b commit 90f6e5c
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 49 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, List, Literal, Optional, Union

from neo4j import GraphDatabase
from pydantic import BaseModel, PrivateAttr, field_validator

Expand All @@ -11,69 +11,72 @@

class Neo4jVectorStore(VectorStoreSaveLoadMixin, VectorStoreRetrieveMixin, VectorStoreBase, BaseModel):
    """
    Vector store backed by a Neo4j graph database.

    Documents are stored as ``(:Document {id, content, metadata})`` nodes,
    with ``metadata`` held as a JSON string on the node.
    """
    type: Literal['Neo4jVectorStore'] = 'Neo4jVectorStore'

    # Connection settings — these are pydantic fields, so they round-trip
    # through serialize()/deserialize().
    uri: str        # e.g. "bolt://localhost:7687"
    user: str
    password: str
    collection_name: Optional[str] = None  # not referenced by the queries below — TODO confirm intended use

    # Private attributes are excluded from serialization by default.
    # NOTE: annotated Optional[Any] — the previous Optional[GraphDatabase.driver]
    # used a factory *method* as a type, which is not a valid annotation.
    _driver: Optional[Any] = PrivateAttr(default=None)

    def __init__(self, **data):
        """
        Validate the pydantic fields, open the Neo4j driver from
        self.uri/self.user/self.password, and ensure the schema exists.
        """
        super().__init__(**data)
        self._driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password))
        self._initialize_schema()

def _initialize_schema(self):
    """
    Initialize the Neo4j schema, creating necessary indexes and constraints.

    Idempotent: ``IF NOT EXISTS`` makes repeated initialization safe.
    """
    with self._driver.session() as session:
        # Named unique constraint on Document.id; MERGE in the add/update
        # methods relies on id being the node's identity.
        session.run("""
            CREATE CONSTRAINT unique_document_id IF NOT EXISTS
            FOR (d:Document)
            REQUIRE d.id IS UNIQUE
            """)

def add_document(self, document: Document) -> None:
    """
    Insert or update a single Document node in Neo4j.

    MERGE keys on the document id, so re-adding an existing id updates
    that node's content/metadata rather than creating a duplicate.

    :param document: Document to add
    """
    metadata_json = json.dumps(document.metadata)
    with self._driver.session() as session:
        session.run(
            """
            MERGE (d:Document {id: $id})
            SET d.content = $content,
                d.metadata = $metadata
            """,
            id=document.id,
            content=document.content,
            metadata=metadata_json,
        )



def add_documents(self, documents: List[Document]) -> None:
    """
    Insert or update multiple Document nodes in Neo4j.

    Runs one MERGE statement per document, all within a single session.

    :param documents: List of documents to add
    """
    upsert_query = """
            MERGE (d:Document {id: $id})
            SET d.content = $content,
                d.metadata = $metadata
            """
    with self._driver.session() as session:
        for doc in documents:
            session.run(
                upsert_query,
                id=doc.id,
                content=doc.content,
                metadata=json.dumps(doc.metadata),
            )



def get_document(self, id: str) -> Union[Document, None]:
"""
Retrieve a document by its ID.
:param id: Document ID
:return: Document object or None if not found
"""

with self._driver.session() as session:
result = session.run("""
MATCH (d:Document {id: $id})
Expand All @@ -86,7 +89,8 @@ def get_document(self, id: str) -> Union[Document, None]:
metadata=json.loads(result['metadata'])
)
return None



def get_all_documents(self) -> List[Document]:
"""
Retrieve all documents from the Neo4j store.
Expand All @@ -106,33 +110,37 @@ def get_all_documents(self) -> List[Document]:
metadata=json.loads(record['metadata'])
))
return documents



def delete_document(self, id: str) -> None:
    """
    Delete a document node by its ID.

    :param id: Document ID
    """
    # DETACH DELETE drops the node along with any relationships attached
    # to it, so the delete cannot fail on a still-connected node.
    delete_query = """
            MATCH (d:Document {id: $id})
            DETACH DELETE d
            """
    with self._driver.session() as session:
        session.run(delete_query, id=id)

def update_document(self, id: str, updated_document: Document) -> None:
    """
    Overwrite the content and metadata of an existing document node.

    No-op when no node with the given id exists (MATCH finds nothing).

    :param id: Document ID
    :param updated_document: Document object with updated data
    """
    new_content = updated_document.content
    new_metadata = json.dumps(updated_document.metadata)
    with self._driver.session() as session:
        session.run(
            """
            MATCH (d:Document {id: $id})
            SET d.content = $content,
                d.metadata = $metadata
            """,
            id=id,
            content=new_content,
            metadata=new_metadata,
        )


def retrieve(self, query: str, top_k: int = 5, string_field: str = 'content') -> List[Document]:
"""
Retrieve the top_k most similar documents to the query based on Levenshtein distance using APOC's apoc.text.distance.
Expand All @@ -142,17 +150,19 @@ def retrieve(self, query: str, top_k: int = 5, string_field: str = 'content') ->
:param string_field: Specific field to apply Levenshtein distance (default: 'content')
:return: List of Document objects
"""

input_text = query

with self._driver.session() as session:
cypher_query = f"""
MATCH (d:Document)
RETURN d.id AS id, d.content AS content, d.metadata AS metadata,
apoc.text.distance(d.{string_field}, $input_text) AS distance
apoc.text.distance(d.{string_field}, $input_text) AS distance
ORDER BY distance ASC
LIMIT $top_k
"""
results = session.run(cypher_query, input_text=input_text, top_k=top_k)

documents = []
for record in results:
documents.append(Document(
Expand All @@ -161,14 +171,32 @@ def retrieve(self, query: str, top_k: int = 5, string_field: str = 'content') ->
metadata=json.loads(record['metadata'])
))
return documents

def close(self):
    """
    Close the Neo4j driver connection, if one was opened.
    """
    driver = self._driver
    if driver:
        driver.close()

def __del__(self):
    """Best-effort cleanup: release the driver when the store is garbage-collected."""
    # A finalizer must never propagate exceptions: __del__ can run after a
    # failed __init__ (attributes missing) or during interpreter shutdown
    # (modules already torn down), where close() may raise.
    try:
        self.close()
    except Exception:
        pass



# Serialization methods
def serialize(self) -> str:
    """
    Serialize the configuration of the store to a JSON string.

    Only pydantic fields are included; private attributes such as the
    driver handle are excluded by default.
    """
    json_config = self.model_dump_json()
    return json_config

@classmethod
def deserialize(cls, json_str: str) -> 'Neo4jVectorStore':
    """
    Deserialize the JSON string to create a new instance of Neo4jVectorStore.

    Reads uri/user/password (required) and collection_name (optional),
    then builds a fresh store — which reconnects to Neo4j in __init__.
    """
    config = json.loads(json_str)
    kwargs = {
        'uri': config['uri'],
        'user': config['user'],
        'password': config['password'],
        'collection_name': config.get('collection_name'),
    }
    return cls(**kwargs)
73 changes: 58 additions & 15 deletions pkgs/community/tests/unit/vector_stores/Neo4jVectorStore_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from Neo4jVectorStore import Neo4jVectorStore
from neo4j import GraphDatabase


# Load environment variables
load_dotenv()
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
Expand All @@ -18,24 +18,38 @@
@pytest.fixture(scope="module")
def vector_store():
    """
    Fixture to initialize and teardown the Neo4jVectorStore.

    Fails fast when connection settings are missing. On teardown, deletes
    every Document node and closes the driver.
    """
    if not all([NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD]):
        pytest.fail("NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set in environment variables.")

    store = Neo4jVectorStore(
        uri=NEO4J_URI,
        user=NEO4J_USER,
        password=NEO4J_PASSWORD,
        collection_name=COLLECTION_NAME  # If applicable
    )
    yield store
    # Teardown: Clean up the test collection if necessary
    try:
        with store._driver.session() as session:
            session.run("""
                MATCH (d:Document)
                DETACH DELETE d
                """)
    except Exception as e:
        pytest.fail(f"Teardown failed: {e}")
    store.close()


@pytest.fixture
def sample_documents():
    """Two small Document instances shared by the add/get round-trip tests."""
    specs = [
        ("doc_sample_1", "Sample Content 1", "value1"),
        ("doc_sample_2", "Sample Content 2", "value2"),
    ]
    return [
        Document(id=doc_id, content=content, metadata={"key": value})
        for doc_id, content, value in specs
    ]


@pytest.mark.unit
def test_ubc_type(vector_store):
"""
def test_serialization(vector_store):
    """
    Test to verify serialization and deserialization of Neo4jVectorStore.

    Round-trips the store config through serialize()/deserialize() and
    checks every serialized field survives.
    """
    serialized = vector_store.serialize()
    deserialized_store = Neo4jVectorStore.deserialize(serialized)
    assert vector_store.uri == deserialized_store.uri
    assert vector_store.user == deserialized_store.user
    assert vector_store.password == deserialized_store.password
    assert vector_store.collection_name == deserialized_store.collection_name


@pytest.mark.unit
def test_top_k(vector_store):
    """
    retrieve() must return exactly top_k results, and they must be the
    closest matches to the query string.

    NOTE(review): previously named ``top_k_test``, which pytest's default
    ``test_*`` collection pattern silently skips — renamed so it runs.
    """
    documents = [
        Document(id="doc_test_1", content="test", metadata={}),
        Document(id="doc_test_2", content="test1", metadata={}),
        Document(id="doc_test_3", content="test2", metadata={}),
        Document(id="doc_test_4", content="test3", metadata={}),
    ]
    vector_store.add_documents(documents)
    retrieved_docs = vector_store.retrieve(query="test", top_k=2)
    assert len(retrieved_docs) == 2
    # Assuming 'doc_test_1' and 'doc_test_2' are the most similar
    expected_ids = {"doc_test_1", "doc_test_2"}
    retrieved_ids = {doc.id for doc in retrieved_docs}
    assert retrieved_ids == expected_ids


@pytest.mark.unit
def test_add_document(vector_store):
    """Single-document round trip: add_document, then fetch it back by id."""
    sample = Document(
        id="doc_add_1",
        content="This is a sample document.",
        metadata={"author": "John Doe", "title": "Sample Document"}
    )
    vector_store.add_document(sample)
    fetched = vector_store.get_document("doc_add_1")
    assert fetched == sample


@pytest.mark.unit
def test_add_documents(vector_store, sample_documents):
    """Bulk round trip: add_documents, then fetch each document back by id."""
    vector_store.add_documents(sample_documents)
    for expected in sample_documents:
        assert vector_store.get_document(expected.id) == expected

0 comments on commit 90f6e5c

Please sign in to comment.