Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing serialization of the graph #932

Merged
merged 1 commit into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python-sdk/indexify/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ def signal_handler(sig, frame):

@app.command(help="Build image for function names")
def build_image(
workflow_file_path: str, func_names: List[str], python_sdk_path: Optional[str] = None
workflow_file_path: str,
func_names: List[str],
python_sdk_path: Optional[str] = None,
):
globals_dict = {}

Expand Down
11 changes: 9 additions & 2 deletions python-sdk/indexify/executor/function_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import concurrent.futures


class FunctionRunException(Exception):
def __init__(
self, exception: Exception, stdout: str, stderr: str, is_reducer: bool
Expand Down Expand Up @@ -131,22 +132,28 @@ def _run_function(
fn_output = None
has_failed = False
exception_msg = None
print(f"[bold] function_worker: [/bold] invoking function {fn_name} in graph {graph_name}")
print(
f"[bold] function_worker: [/bold] invoking function {fn_name} in graph {graph_name}"
)
with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
try:
key = f"{namespace}/{graph_name}/{version}/{fn_name}"
if key not in function_wrapper_map:
_load_function(namespace, graph_name, fn_name, code_path, version)

fn = function_wrapper_map[key]
if str(type(fn.indexify_function)) == "<class 'indexify.functions_sdk.indexify_functions.IndexifyRo'>":
if (
str(type(fn.indexify_function))
== "<class 'indexify.functions_sdk.indexify_functions.IndexifyRo'>"
):
router_output = fn.invoke_router(fn_name, input)
else:
fn_output = fn.invoke_fn_ser(fn_name, input, init_value)

is_reducer = fn.indexify_function.accumulate is not None
except Exception as e:
import sys

print(traceback.format_exc(), file=sys.stderr)
has_failed = True
exception_msg = str(e)
Expand Down
7 changes: 5 additions & 2 deletions python-sdk/indexify/functions_sdk/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,16 @@ def route(
self.routers[from_node.name].append(node.name)
return self

def serialize(self):
def serialize(self, additional_modules):
# Get all unique modules from nodes and edges
pickled_functions = {}
for module in additional_modules:
cloudpickle.register_pickle_by_value(module)
for node in self.nodes.values():
cloudpickle.register_pickle_by_value(sys.modules[node.__module__])
pickled_functions[node.name] = cloudpickle.dumps(node)
cloudpickle.unregister_pickle_by_value(sys.modules[node.__module__])
if not sys.modules[node.__module__] in additional_modules:
cloudpickle.unregister_pickle_by_value(sys.modules[node.__module__])
return pickled_functions

def add_edge(
Expand Down
17 changes: 10 additions & 7 deletions python-sdk/indexify/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, traceback):
self.close()

def register_compute_graph(self, graph: Graph):
def register_compute_graph(self, graph: Graph, additional_modules):
graph_metadata = graph.definition()
serialized_code = cloudpickle.dumps(graph.serialize())
serialized_code = cloudpickle.dumps(graph.serialize(additional_modules))
response = self._post(
f"namespaces/{self.namespace}/compute_graphs",
files={"code": serialized_code},
Expand Down Expand Up @@ -197,9 +197,11 @@ def namespaces(self) -> List[str]:
for item in namespaces_dict:
namespaces.append(item["name"])
return namespaces

@classmethod
def new_namespace(cls, namespace: str, server_addr: Optional[str] = "http://localhost:8900"):
def new_namespace(
cls, namespace: str, server_addr: Optional[str] = "http://localhost:8900"
):
# Create a new client instance with the specified server address
client = cls(service_url=server_addr)

Expand All @@ -212,11 +214,10 @@ def new_namespace(cls, namespace: str, server_addr: Optional[str] = "http://loca

# Set the namespace for the newly created client
client.namespace = namespace

# Return the client instance with the new namespace
return client


def create_namespace(self, namespace: str):
self._post("namespaces", json={"name": namespace})

Expand Down Expand Up @@ -259,7 +260,9 @@ def invoke_graph_with_object(
return v["id"]
if k == "DiagnosticMessage":
message = v.get("message", None)
print(f"[bold red]scheduler diagnostic: [/bold red]{message}")
print(
f"[bold red]scheduler diagnostic: [/bold red]{message}"
)
continue
event_payload = InvocationEventPayload.model_validate(v)
event = InvocationEvent(event_name=k, payload=event_payload)
Expand Down
14 changes: 8 additions & 6 deletions python-sdk/indexify/remote_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def foo(x: int) -> int:
return self._client.invoke_graph_with_object(
self._name, block_until_done, **kwargs
)

def rerun(self):
"""
Rerun the graph with the given invocation ID.
Expand All @@ -41,17 +41,19 @@ def rerun(self):
self._client.rerun_graph(self._name)

@classmethod
def deploy(cls, g: Graph, additional_modules=[], server_url: Optional[str] = "http://localhost:8900"):
def deploy(
cls,
g: Graph,
additional_modules=[],
server_url: Optional[str] = "http://localhost:8900",
):
"""
Create a new RemoteGraph from a local Graph object.
:param g: The local Graph object.
:param server_url: The URL of the server where the graph will be registered.
"""
import cloudpickle
for module in additional_modules:
cloudpickle.register_pickle_by_value(module)
client = IndexifyClient(service_url=server_url)
client.register_compute_graph(g)
client.register_compute_graph(g, additional_modules)
return cls(name=g.name, server_url=server_url)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion python-sdk/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "indexify"
version = "0.2.10"
version = "0.2.11"
description = "Python Client for Indexify"
authors = ["Tensorlake Inc. <[email protected]>"]
license = "Apache 2.0"
Expand Down
24 changes: 13 additions & 11 deletions python-sdk/tests/test_function_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
import unittest
from typing import List, Mapping, Union

from pydantic import BaseModel
import msgpack
import cloudpickle
from pydantic import BaseModel

from indexify import Graph
from indexify.executor.function_worker import FunctionWorker
from indexify.functions_sdk.data_objects import File, IndexifyData
from indexify.functions_sdk.indexify_functions import indexify_function, IndexifyFunctionWrapper
from indexify.functions_sdk.indexify_functions import (
IndexifyFunctionWrapper,
indexify_function,
)


@indexify_function()
Expand Down Expand Up @@ -112,15 +114,15 @@ async def test_function_worker_extractor_raises_error(self):
temp_file_path = temp_file.name

result = await self.function_worker.async_submit(
namespace="test",
graph_name="test",
fn_name="extractor_exception",
input=IndexifyData(id="123", payload=cloudpickle.dumps(10)),
code_path=temp_file_path,
version=1,
)
namespace="test",
graph_name="test",
fn_name="extractor_exception",
input=IndexifyData(id="123", payload=cloudpickle.dumps(10)),
code_path=temp_file_path,
version=1,
)
assert not result.success
assert(result.exception == "this extractor throws an exception.")
assert result.exception == "this extractor throws an exception."


if __name__ == "__main__":
Expand Down
Loading