Skip to content

Commit

Permalink
Merge pull request #528 from MichaelDecent/0.4.5.dev1
Browse files (browse the repository at this point in the history)
Fixed collection errors in the community package
  • Loading branch information
cobycloud authored Sep 24, 2024
2 parents 90f6e5c + e1b4f20 commit f81c257
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 132 deletions.
7 changes: 4 additions & 3 deletions pkgs/community/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
long_description=swarmauri_community.__long_desc__,
long_description_content_type="text/markdown",
url="http://github.com/swarmauri/swarmauri-sdk",
license='Apache Software License',
license="Apache Software License",
packages=find_packages(
include=["swarmauri_community*"]
), # Include the swarmauri_community package and its subpackages
Expand All @@ -19,8 +19,9 @@
"requests",
"pydantic",
"pymupdf",
"neo4j",
"swarmauri-core==0.5.0.dev8",
"swarmauri==0.5.0.dev8"
"swarmauri==0.5.0.dev8",
],
extras_require={
"full": [
Expand Down Expand Up @@ -62,7 +63,7 @@
"pacmap",
"tf-keras",
"pinecone",
"neo4j"
"neo4j",
]
},
classifiers=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class MutualInformationMetric(MetricBase, MetricCalculateMixin):
"""

type: Literal["MutualInformationMetric"] = "MutualInformationMetric"
unit: str = "bits"

def calculate(self, data: pd.DataFrame, target_column: str) -> float:
"""
Expand Down
135 changes: 81 additions & 54 deletions pkgs/community/swarmauri_community/vector_stores/Neo4jVectorStore.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
from neo4j import GraphDatabase
import json

from swarmauri.standard.documents.concrete.Document import Document
from swarmauri.standard.vector_stores.base.VectorStoreBase import VectorStoreBase
from swarmauri.standard.vector_stores.base.VectorStoreRetrieveMixin import VectorStoreRetrieveMixin
from swarmauri.standard.vector_stores.base.VectorStoreSaveLoadMixin import VectorStoreSaveLoadMixin
from swarmauri.documents.concrete.Document import Document
from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase
from swarmauri.vector_stores.base.VectorStoreRetrieveMixin import (
VectorStoreRetrieveMixin,
)
from swarmauri.vector_stores.base.VectorStoreSaveLoadMixin import (
VectorStoreSaveLoadMixin,
)



class Neo4jVectorStore(VectorStoreSaveLoadMixin, VectorStoreRetrieveMixin, VectorStoreBase, BaseModel):
Expand All @@ -32,11 +37,15 @@ def _initialize_schema(self):

with self._driver.session() as session:
# Create a unique constraint on Document ID with a specific constraint name
session.run("""
CREATE CONSTRAINT unique_document_id IF NOT EXISTS
FOR (d:Document)
REQUIRE d.id IS UNIQUE
""")

session.run(
"""
CREATE CONSTRAINT unique_document_id IF NOT EXISTS
FOR (d:Document)
REQUIRE d.id IS UNIQUE
"""
)


def add_document(self, document: Document) -> None:
"""
Expand All @@ -46,11 +55,17 @@ def add_document(self, document: Document) -> None:
"""

with self._driver.session() as session:
session.run("""
session.run(
"""
MERGE (d:Document {id: $id})
SET d.content = $content,
d.metadata = $metadata
""", id=document.id, content=document.content, metadata=json.dumps(document.metadata))
""",
id=document.id,
content=document.content,
metadata=json.dumps(document.metadata),
)


def add_documents(self, documents: List[Document]) -> None:
Expand All @@ -62,11 +77,17 @@ def add_documents(self, documents: List[Document]) -> None:

with self._driver.session() as session:
for document in documents:
session.run("""
session.run(
"""
MERGE (d:Document {id: $id})
SET d.content = $content,
d.metadata = $metadata
""", id=document.id, content=document.content, metadata=json.dumps(document.metadata))
""",
id=document.id,
content=document.content,
metadata=json.dumps(document.metadata),
)


def get_document(self, id: str) -> Union[Document, None]:
Expand All @@ -78,15 +99,18 @@ def get_document(self, id: str) -> Union[Document, None]:
"""

with self._driver.session() as session:
result = session.run("""
result = session.run(
"""
MATCH (d:Document {id: $id})
RETURN d.id AS id, d.content AS content, d.metadata AS metadata
""", id=id).single()
""",
id=id,
).single()
if result:
return Document(
id=result['id'],
content=result['content'],
metadata=json.loads(result['metadata'])
id=result["id"],
content=result["content"],
metadata=json.loads(result["metadata"]),
)
return None

Expand All @@ -98,17 +122,21 @@ def get_all_documents(self) -> List[Document]:
:return: List of Document objects
"""
with self._driver.session() as session:
results = session.run("""
results = session.run(
"""
MATCH (d:Document)
RETURN d.id AS id, d.content AS content, d.metadata AS metadata
""")
"""
)
documents = []
for record in results:
documents.append(Document(
id=record['id'],
content=record['content'],
metadata=json.loads(record['metadata'])
))
documents.append(
Document(
id=record["id"],
content=record["content"],
metadata=json.loads(record["metadata"]),
)
)
return documents


Expand All @@ -120,10 +148,14 @@ def delete_document(self, id: str) -> None:
"""

with self._driver.session() as session:
session.run("""
session.run(
"""
MATCH (d:Document {id: $id})
DETACH DELETE d
""", id=id)
""",
id=id,
)


def update_document(self, id: str, updated_document: Document) -> None:
"""
Expand All @@ -134,14 +166,22 @@ def update_document(self, id: str, updated_document: Document) -> None:
"""

with self._driver.session() as session:
session.run("""
session.run(
"""
MATCH (d:Document {id: $id})
SET d.content = $content,
d.metadata = $metadata
""", id=id, content=updated_document.content, metadata=json.dumps(updated_document.metadata))

def retrieve(self, query: str, top_k: int = 5, string_field: str = 'content') -> List[Document]:
""",
id=id,
content=updated_document.content,
metadata=json.dumps(updated_document.metadata),
)

def retrieve(
self, query: str, top_k: int = 5, string_field: str = "content"
) -> List[Document]:

"""
Retrieve the top_k most similar documents to the query based on Levenshtein distance using APOC's apoc.text.distance.
Expand All @@ -165,11 +205,13 @@ def retrieve(self, query: str, top_k: int = 5, string_field: str = 'content') ->

documents = []
for record in results:
documents.append(Document(
id=record['id'],
content=record['content'],
metadata=json.loads(record['metadata'])
))
documents.append(
Document(
id=record["id"],
content=record["content"],
metadata=json.loads(record["metadata"]),
)
)
return documents

def close(self):
Expand All @@ -181,22 +223,7 @@ def close(self):
self._driver.close()


# Serialization methods
def serialize(self) -> str:
"""
Serialize the configuration of the store.
"""
return self.model_dump_json()
def __del__(self):
self.close()


@classmethod
def deserialize(cls, json_str: str) -> 'Neo4jVectorStore':
"""
Deserialize the JSON string to create a new instance of Neo4jVectorStore.
"""
data = json.loads(json_str)
return cls(
uri=data['uri'],
user=data['user'],
password=data['password'],
collection_name=data.get('collection_name')
)
Loading

0 comments on commit f81c257

Please sign in to comment.