Skip to content

Commit

Permalink
Improve Aryn reader (#1172)
Browse files Browse the repository at this point in the history
* Ensure parent docs are collected during doc reconstruct

* mock using patch (#1160)

* Add a client for Aryn, use the new client to read docs

* Use list_docs and get_doc for reading from Aryn

* Mark Aryn classes experimental

* Fix lint

---------

Co-authored-by: Dhruv Kaliraman <[email protected]>
  • Loading branch information
austintlee and dhruvkaliraman7 authored Feb 14, 2025
1 parent 7f008bf commit 2227815
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 25 deletions.
131 changes: 113 additions & 18 deletions lib/sycamore/sycamore/connectors/aryn/ArynReader.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
import io
import json
import logging
import struct
from dataclasses import dataclass
from typing import Any
from time import time
from typing import Any, TYPE_CHECKING

import requests
from requests import Response
import httpx

from sycamore.connectors.aryn.client import ArynClient

from sycamore.connectors.base_reader import BaseDBReader
from sycamore.data import Document
from sycamore.data.element import create_element
from sycamore.decorators import experimental

if TYPE_CHECKING:
from ray.data import Dataset

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


@dataclass
Expand Down Expand Up @@ -41,39 +52,123 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
return docs


class ArynReaderClient(BaseDBReader.Client):
    """Reader-side client that pulls the documents of a docset from an Aryn endpoint.

    It keeps the endpoint URL and API key from ``client_params`` for the
    streaming ``/read`` call, and holds an ``ArynClient`` in ``_client``.
    """

    # Size in bytes of the length header that precedes every document frame.
    _HEADER_SIZE = 4

    def __init__(self, client: ArynClient, client_params: ArynClientParams, **kwargs):
        """Store the wrapped ``ArynClient``, the connection parameters and any extra kwargs."""
        self.aryn_url = client_params.aryn_url
        self.api_key = client_params.api_key
        self._client = client
        self.kwargs = kwargs

    def read_records(self, query_params: "BaseDBReader.QueryParams") -> "ArynQueryResponse":
        """Stream every document of the docset and return them as one response.

        The response body is parsed as a sequence of frames, each made of a
        4-byte big-endian signed length followed by that many bytes of JSON.
        Frames may be split at arbitrary positions across network chunks, so
        incoming bytes are accumulated and only complete frames are decoded.

        Args:
            query_params: Must be an ``ArynQueryParams``; its ``docset_id``
                selects the docset to read.

        Returns:
            An ``ArynQueryResponse`` wrapping the decoded JSON documents.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
            ValueError: If a frame announces a negative length or the stream
                ends in the middle of a frame.
        """
        assert isinstance(query_params, ArynQueryParams)
        headers = {"Authorization": f"Bearer {self.api_key}"}
        url = f"{self.aryn_url}/docsets/{query_params.docset_id}/read"

        logger.debug("Reading from docset: %s", query_params.docset_id)
        t0 = time()
        docs: list[dict[str, Any]] = []
        pending = bytearray()
        # Both the client and the streamed response are context-managed so the
        # connection is released even when parsing fails part-way through.
        with httpx.Client() as client, client.stream("POST", url, headers=headers) as response:
            response.raise_for_status()
            for chunk in response.iter_bytes():
                pending += chunk
                docs.extend(self._drain_frames(pending))

        if pending:
            raise ValueError(
                f"Stream for docset {query_params.docset_id} ended with {len(pending)} bytes of an incomplete document"
            )

        logger.debug("Reading took: %s seconds", time() - t0)
        return ArynQueryResponse(docs)

    @classmethod
    def _drain_frames(cls, pending: bytearray) -> list[dict[str, Any]]:
        """Decode every complete frame at the front of ``pending`` and remove it in place.

        Bytes belonging to a not-yet-complete frame (including a partial
        length header) are left in ``pending`` for the next call.
        """
        docs: list[dict[str, Any]] = []
        offset = 0
        while len(pending) - offset >= cls._HEADER_SIZE:
            (size,) = struct.unpack_from("!i", pending, offset)
            if size < 0:
                raise ValueError(f"Invalid document size in stream: {size}")
            end = offset + cls._HEADER_SIZE + size
            if end > len(pending):
                # The payload continues in a later chunk.
                break
            docs.append(json.loads(pending[offset + cls._HEADER_SIZE : end].decode()))
            offset = end
        del pending[:offset]
        return docs

    def check_target_presence(self, query_params: "BaseDBReader.QueryParams") -> bool:
        """Always report the target as present; no request is made."""
        return True

    @classmethod
    def from_client_params(cls, params: "BaseDBReader.ClientParams") -> "ArynReaderClient":
        """Build a reader client, and the ``ArynClient`` it wraps, from ``ArynClientParams``."""
        assert isinstance(params, ArynClientParams)
        client = ArynClient(params.aryn_url, params.api_key)
        return cls(client, params)


# Reader node that loads the documents of one Aryn docset into a Ray Dataset.
@experimental
class ArynReader(BaseDBReader):
# NOTE(review): Client is assigned twice in a row; the second assignment
# (ArynReaderClient) overwrites the first, so the first has no lasting effect.
Client = ArynClient
Client = ArynReaderClient
Record = ArynQueryResponse
ClientParams = ArynClientParams
QueryParams = ArynQueryParams

# Forwards the client/query parameters and any extra kwargs to BaseDBReader.
def __init__(
self,
client_params: ArynClientParams,
query_params: ArynQueryParams,
**kwargs,
):
super().__init__(client_params=client_params, query_params=query_params, **kwargs)

# Per-row map function: takes a row holding a "doc_id", fetches that document
# with get_doc, and returns a row of the form {"doc": <serialized Document>}.
def _to_doc(self, doc: dict[str, Any]) -> dict[str, Any]:
assert isinstance(self._client_params, ArynClientParams)
assert isinstance(self._query_params, ArynQueryParams)

# A fresh client is built from the stored params on every call.
client = self.Client.from_client_params(self._client_params)
aryn_client = client._client

# From here on `doc` is the fetched document dict, not the input row.
doc = aryn_client.get_doc(self._query_params.docset_id, doc["doc_id"])
elements = doc.get("elements", [])
document = Document(**doc)
# Replace the raw element dicts with the objects returned by create_element.
document.data["elements"] = [create_element(**element) for element in elements]
return {"doc": Document.serialize(document)}

# Lists the doc ids of the docset, turns them into a Ray Dataset of
# {"doc_id": ...} rows, and maps _to_doc over it. The **kwargs are not used here.
def execute(self, **kwargs) -> "Dataset":

assert isinstance(self._client_params, ArynClientParams)
assert isinstance(self._query_params, ArynQueryParams)

client = self.Client.from_client_params(self._client_params)
aryn_client = client._client

# TODO paginate
docs = aryn_client.list_docs(self._query_params.docset_id)
logger.debug(f"Found {len(docs)} docs in docset: {self._query_params.docset_id}")

# Imported inside the method rather than at module level (the module-level
# Dataset import is guarded by TYPE_CHECKING).
from ray.data import from_items

ds = from_items([{"doc_id": doc_id} for doc_id in docs])
return ds.map(self._to_doc)
2 changes: 2 additions & 0 deletions lib/sycamore/sycamore/connectors/aryn/ArynWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from sycamore.connectors.base_writer import BaseDBWriter
from sycamore.data import Document
from sycamore.decorators import experimental


@dataclass
Expand Down Expand Up @@ -67,6 +68,7 @@ def get_existing_target_params(self, target_params: "BaseDBWriter.TargetParams")
pass


@experimental
class ArynWriter(BaseDBWriter):
Client = ArynWriterClient
Record = ArynWriterRecord
Expand Down
52 changes: 52 additions & 0 deletions lib/sycamore/sycamore/connectors/aryn/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import logging
from typing import Any

import requests

from sycamore.decorators import experimental

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


@experimental
class ArynClient:
    """Small HTTP client for the docset endpoints of an Aryn service.

    Every request is authenticated with a bearer token built from ``api_key``.
    All failures (transport errors, error statuses, malformed responses) are
    reported as ``ValueError`` with the underlying exception chained.
    """

    def __init__(self, aryn_url: str, api_key: str):
        """Remember the base URL and the API key used for every request."""
        self.aryn_url = aryn_url
        self.api_key = api_key

    def _auth_headers(self) -> dict[str, str]:
        """Return the Authorization header shared by every request."""
        return {"Authorization": f"Bearer {self.api_key}"}

    def list_docs(self, docset_id: str) -> list[str]:
        """Return the ``doc_id`` of every item listed for ``docset_id``.

        Raises:
            ValueError: If the request fails, the server returns an error
                status, or the response lacks the expected ``items``/``doc_id`` keys.
        """
        try:
            response = requests.get(f"{self.aryn_url}/docsets/{docset_id}/docs", headers=self._auth_headers())
            # Fail on HTTP errors here instead of on a confusing KeyError below.
            response.raise_for_status()
            items = response.json()["items"]
            return [item["doc_id"] for item in items]
        except Exception as e:
            raise ValueError(f"Error listing docs: {e}") from e

    def get_doc(self, docset_id: str, doc_id: str) -> dict[str, Any]:
        """Fetch one document of a docset as the decoded JSON body.

        Raises:
            ValueError: If the request fails, the status is not 200, the body
                is not valid JSON, or the decoded body is ``None``.
        """
        try:
            response = requests.get(
                f"{self.aryn_url}/docsets/{docset_id}/docs/{doc_id}",
                headers=self._auth_headers(),
            )
        except Exception as e:
            raise ValueError(f"Error getting doc {doc_id}: {e}") from e

        # Validation errors are raised outside the try blocks so they are not
        # caught and wrapped a second time.
        if response.status_code != 200:
            raise ValueError(
                f"Error getting doc {doc_id}, received {response.status_code} {response.text} {response.reason}"
            )

        try:
            doc = response.json()
        except Exception as e:
            raise ValueError(f"Error getting doc {doc_id}: {e}") from e

        if doc is None:
            raise ValueError(f"Received None for doc {doc_id}")
        logger.debug("Got doc %s", doc)
        return doc

    def create_docset(self, name: str) -> str:
        """Create a docset called ``name`` and return its ``docset_id``.

        Raises:
            ValueError: If the request fails, the server returns an error
                status, or the response has no ``docset_id``.
        """
        try:
            response = requests.post(f"{self.aryn_url}/docsets", json={"name": name}, headers=self._auth_headers())
            response.raise_for_status()
            return response.json()["docset_id"]
        except Exception as e:
            raise ValueError(f"Error creating docset: {e}") from e
15 changes: 15 additions & 0 deletions lib/sycamore/sycamore/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import warnings


def experimental(cls):
    """
    Decorator to mark a class or a function/method as experimental.

    A ``FutureWarning`` is emitted each time a decorated class is instantiated
    or a decorated function is called.

    Classes are returned as the same class object (only ``__init__`` is
    wrapped), so ``isinstance`` checks, subclassing, class attributes and
    pickling by name keep working. Functions are wrapped with
    ``functools.wraps`` so their name, docstring and signature metadata are
    preserved.

    Args:
        cls: The class or callable to mark as experimental.

    Returns:
        The same class with a warning-emitting ``__init__``, or a wrapper
        function around the given callable.
    """
    # Local import: this module's top-level imports only provide ``warnings``.
    import functools

    if isinstance(cls, type):
        original_init = cls.__init__

        @functools.wraps(original_init)
        def __init__(self, *args, **kwargs):
            # stacklevel=2 attributes the warning to the code instantiating the class.
            warnings.warn(
                f"Class {cls.__name__} is experimental and may change in the future.", FutureWarning, stacklevel=2
            )
            original_init(self, *args, **kwargs)

        cls.__init__ = __init__
        return cls

    @functools.wraps(cls)
    def wrapper(*args, **kwargs):
        warnings.warn(
            f"Function {cls.__name__} is experimental and may change in the future.", FutureWarning, stacklevel=2
        )
        return cls(*args, **kwargs)

    return wrapper
2 changes: 2 additions & 0 deletions lib/sycamore/sycamore/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from sycamore.connectors.doc_reconstruct import DocumentReconstructor
from sycamore.context import context_params
from sycamore.decorators import experimental
from sycamore.plan_nodes import Node
from sycamore import Context, DocSet
from sycamore.data import Document
Expand Down Expand Up @@ -634,6 +635,7 @@ def qdrant(self, client_params: dict, query_params: dict, **kwargs) -> DocSet:
)
return DocSet(self._context, wr)

@experimental
def aryn(
self, docset_id: str, aryn_api_key: Optional[str] = None, aryn_url: Optional[str] = None, **kwargs
) -> DocSet:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os

import pytest

from sycamore.connectors.aryn.client import ArynClient


# Value of the ARYN_ENDPOINT environment variable (None when it is not set);
# the manual tests below pass it to ArynClient as the service URL.
aryn_endpoint = os.getenv("ARYN_ENDPOINT")


@pytest.mark.skip(reason="For manual testing only")
def test_list_docs():
    """Manual check: print every entry that ``ArynClient.list_docs`` returns.

    Fill in the docset id below and set ARYN_ENDPOINT / ARYN_TEST_API_KEY
    before removing the skip marker.
    """
    api_key = os.getenv("ARYN_TEST_API_KEY")
    aryn = ArynClient(aryn_url=f"{aryn_endpoint}", api_key=api_key)
    target_docset = ""
    for entry in aryn.list_docs(target_docset):
        print(entry)


@pytest.mark.skip(reason="For manual testing only")
def test_get_doc():
    """Manual check: list a docset, then fetch and print each document in turn.

    Fill in the docset id below and set ARYN_ENDPOINT / ARYN_TEST_API_KEY
    before removing the skip marker.
    """
    api_key = os.getenv("ARYN_TEST_API_KEY")
    aryn = ArynClient(aryn_url=f"{aryn_endpoint}", api_key=api_key)
    target_docset = ""
    for doc_id in aryn.list_docs(target_docset):
        print(doc_id)
        fetched = aryn.get_doc(target_docset, doc_id)
        print(fetched)
18 changes: 11 additions & 7 deletions lib/sycamore/sycamore/writer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
from typing import Any, Callable, Optional, Union, TYPE_CHECKING

import requests
from pyarrow.fs import FileSystem

from sycamore.connectors.aryn.client import ArynClient
from sycamore.context import Context, ExecMode, context_params
from sycamore.connectors.common import HostAndPort
from sycamore.connectors.file.file_writer import default_doc_to_bytes, default_filename, FileWriter, JsonWriter
from sycamore.data import Document
from sycamore.decorators import experimental
from sycamore.executor import Execution
from sycamore.plan_nodes import Node
from sycamore.docset import DocSet
Expand Down Expand Up @@ -543,6 +544,7 @@ def elasticsearch(
)
return self._maybe_execute(es_docs, execute)

@experimental
@requires_modules("neo4j", extra="neo4j")
def neo4j(
self,
Expand Down Expand Up @@ -811,6 +813,7 @@ def json(

self._maybe_execute(node, True)

@experimental
def aryn(
self,
docset_id: Optional[str] = None,
Expand All @@ -824,8 +827,6 @@ def aryn(
Args:
docset_id: The id of the docset to write to. If not provided, a new docset will be created.
create_new_docset: If true, a new docset will be created. If false, the docset with the provided
id will be used.
name: The name of the new docset to create. Required if create_new_docset is true.
aryn_api_key: The api key to use for authentication. If not provided, the api key from the config
file will be used.
Expand All @@ -848,10 +849,13 @@ def aryn(
raise ValueError("Either docset_id or name must be provided")

if docset_id is None and name is not None:
headers = {"Authorization": f"Bearer {aryn_api_key}"}
res = requests.post(url=f"{aryn_url}/docsets", data={"name": name}, headers=headers)
docset_id = res.json()["docset_id"]

try:
aryn_client = ArynClient(aryn_url, aryn_api_key)
docset_id = aryn_client.create_docset(name)
logger.info(f"Created new docset with id {docset_id} and name {name}")
except Exception as e:
logger.error(f"Error creating new docset: {e}")
raise e
client_params = ArynWriterClientParams(aryn_url, aryn_api_key)
target_params = ArynWriterTargetParams(docset_id)
ds = ArynWriter(self.plan, client_params=client_params, target_params=target_params, **kwargs)
Expand Down

0 comments on commit 2227815

Please sign in to comment.