Skip to content

Commit

Permalink
Exclude complex values from checkpoint metadata (#3448)
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos authored Feb 15, 2025
2 parents 405da6d + 62f004f commit c2a129c
Show file tree
Hide file tree
Showing 22 changed files with 88 additions and 670 deletions.
13 changes: 2 additions & 11 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CheckpointMetadata,
CheckpointTuple,
get_checkpoint_id,
get_checkpoint_metadata,
)
from langgraph.checkpoint.postgres import _internal
from langgraph.checkpoint.postgres.base import BasePostgresSaver
Expand Down Expand Up @@ -317,17 +318,7 @@ def put(
checkpoint["id"],
checkpoint_id,
Jsonb(self._dump_checkpoint(copy)),
self._dump_metadata(
{
**{
k: v
for k, v in config["configurable"].items()
if not k.startswith("__")
},
**config.get("metadata", {}),
**metadata,
}
),
self._dump_metadata(get_checkpoint_metadata(config, metadata)),
),
)
return next_config
Expand Down
13 changes: 2 additions & 11 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CheckpointMetadata,
CheckpointTuple,
get_checkpoint_id,
get_checkpoint_metadata,
)
from langgraph.checkpoint.postgres import _ainternal
from langgraph.checkpoint.postgres.base import BasePostgresSaver
Expand Down Expand Up @@ -275,17 +276,7 @@ async def aput(
checkpoint["id"],
checkpoint_id,
Jsonb(self._dump_checkpoint(copy)),
self._dump_metadata(
{
**{
k: v
for k, v in config["configurable"].items()
if not k.startswith("__")
},
**config.get("metadata", {}),
**metadata,
}
),
self._dump_metadata(get_checkpoint_metadata(config, metadata)),
),
)
return next_config
Expand Down
25 changes: 3 additions & 22 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/shallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
get_checkpoint_metadata,
)
from langgraph.checkpoint.postgres import _ainternal, _internal
from langgraph.checkpoint.postgres.base import BasePostgresSaver
Expand Down Expand Up @@ -423,17 +424,7 @@ def put(
thread_id,
checkpoint_ns,
Jsonb(self._dump_checkpoint(copy)),
self._dump_metadata(
{
**{
k: v
for k, v in config["configurable"].items()
if not k.startswith("__")
},
**config.get("metadata", {}),
**metadata,
}
),
self._dump_metadata(get_checkpoint_metadata(config, metadata)),
),
)
return next_config
Expand Down Expand Up @@ -752,17 +743,7 @@ async def aput(
thread_id,
checkpoint_ns,
Jsonb(self._dump_checkpoint(copy)),
self._dump_metadata(
{
**{
k: v
for k, v in config["configurable"].items()
if not k.startswith("__")
},
**config.get("metadata", {}),
**metadata,
}
),
self._dump_metadata(get_checkpoint_metadata(config, metadata)),
),
)
return next_config
Expand Down
31 changes: 10 additions & 21 deletions libs/checkpoint-postgres/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 8 additions & 2 deletions libs/checkpoint-postgres/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langgraph-checkpoint-postgres"
version = "2.0.14"
version = "2.0.15"
description = "Library with a Postgres implementation of LangGraph checkpoint saver."
authors = []
license = "MIT"
Expand All @@ -22,10 +22,10 @@ pytest = "^7.2.1"
anyio = "^4.4.0"
pytest-asyncio = "^0.21.1"
pytest-mock = "^3.11.1"
pytest-watch = "^4.2.0"
mypy = "^1.10.0"
psycopg = {extras = ["binary"], version = ">=3.0.0"}
langgraph-checkpoint = {path = "../checkpoint", develop = true}
pytest-watcher = "^0.4.3"

[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
Expand Down Expand Up @@ -61,3 +61,9 @@ warn_unused_ignores = "True"
warn_redundant_casts = "True"
allow_redefinition = "True"
disable_error_code = "typeddict-item, return-value"

[tool.pytest-watcher]
now = true
delay = 0.1
runner_args = ["--ff", "-x", "-v", "--tb", "short"]
patterns = ["*.py"]
10 changes: 7 additions & 3 deletions libs/checkpoint-postgres/tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from psycopg_pool import AsyncConnectionPool

from langgraph.checkpoint.base import (
EXCLUDED_METADATA_KEYS,
Checkpoint,
CheckpointMetadata,
create_checkpoint,
Expand All @@ -23,6 +24,10 @@
from tests.conftest import DEFAULT_POSTGRES_URI


def _exclude_keys(config: dict[str, Any]) -> dict[str, Any]:
return {k: v for k, v in config.items() if k not in EXCLUDED_METADATA_KEYS}


@asynccontextmanager
async def _pool_saver():
"""Fixture for pool mode testing."""
Expand Down Expand Up @@ -223,7 +228,6 @@ async def test_combined_metadata(saver_name: str, test_data) -> None:
assert checkpoint.metadata == {
**metadata,
"thread_id": "thread-2",
"checkpoint_ns": "",
"run_id": "my_run_id",
}

Expand Down Expand Up @@ -251,14 +255,14 @@ async def test_asearch(saver_name: str, test_data) -> None:
search_results_1 = [c async for c in saver.alist(None, filter=query_1)]
assert len(search_results_1) == 1
assert search_results_1[0].metadata == {
**configs[0]["configurable"],
**_exclude_keys(configs[0]["configurable"]),
**metadata[0],
}

search_results_2 = [c async for c in saver.alist(None, filter=query_2)]
assert len(search_results_2) == 1
assert search_results_2[0].metadata == {
**configs[1]["configurable"],
**_exclude_keys(configs[1]["configurable"]),
**metadata[1],
}

Expand Down
10 changes: 7 additions & 3 deletions libs/checkpoint-postgres/tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from psycopg_pool import ConnectionPool

from langgraph.checkpoint.base import (
EXCLUDED_METADATA_KEYS,
Checkpoint,
CheckpointMetadata,
create_checkpoint,
Expand All @@ -21,6 +22,10 @@
from tests.conftest import DEFAULT_POSTGRES_URI


def _exclude_keys(config: dict[str, Any]) -> dict[str, Any]:
return {k: v for k, v in config.items() if k not in EXCLUDED_METADATA_KEYS}


@contextmanager
def _pool_saver():
"""Fixture for pool mode testing."""
Expand Down Expand Up @@ -205,7 +210,6 @@ def test_combined_metadata(saver_name: str, test_data) -> None:
assert checkpoint.metadata == {
**metadata,
"thread_id": "thread-2",
"checkpoint_ns": "",
"run_id": "my_run_id",
}

Expand Down Expand Up @@ -233,14 +237,14 @@ def test_search(saver_name: str, test_data) -> None:
search_results_1 = list(saver.list(None, filter=query_1))
assert len(search_results_1) == 1
assert search_results_1[0].metadata == {
**configs[0]["configurable"],
**_exclude_keys(configs[0]["configurable"]),
**metadata[0],
}

search_results_2 = list(saver.list(None, filter=query_2))
assert len(search_results_2) == 1
assert search_results_2[0].metadata == {
**configs[1]["configurable"],
**_exclude_keys(configs[1]["configurable"]),
**metadata[1],
}

Expand Down
11 changes: 2 additions & 9 deletions libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
CheckpointTuple,
SerializerProtocol,
get_checkpoint_id,
get_checkpoint_metadata,
)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import ChannelProtocol
Expand Down Expand Up @@ -398,15 +399,7 @@ def put(
checkpoint_ns = config["configurable"]["checkpoint_ns"]
type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)
serialized_metadata = self.jsonplus_serde.dumps(
{
**{
k: v
for k, v in config["configurable"].items()
if not k.startswith("__")
},
**config.get("metadata", {}),
**metadata,
}
get_checkpoint_metadata(config, metadata)
)
with self.cursor() as cur:
cur.execute(
Expand Down
11 changes: 2 additions & 9 deletions libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CheckpointTuple,
SerializerProtocol,
get_checkpoint_id,
get_checkpoint_metadata,
)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import ChannelProtocol
Expand Down Expand Up @@ -464,15 +465,7 @@ async def aput(
checkpoint_ns = config["configurable"]["checkpoint_ns"]
type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)
serialized_metadata = self.jsonplus_serde.dumps(
{
**{
k: v
for k, v in config["configurable"].items()
if not k.startswith("__")
},
**config.get("metadata", {}),
**metadata,
}
get_checkpoint_metadata(config, metadata)
)
async with (
self.lock,
Expand Down
2 changes: 1 addition & 1 deletion libs/checkpoint-sqlite/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langgraph-checkpoint-sqlite"
version = "2.0.4"
version = "2.0.5"
description = "Library with a SQLite implementation of LangGraph checkpoint saver."
authors = []
license = "MIT"
Expand Down
6 changes: 3 additions & 3 deletions libs/checkpoint-sqlite/tests/test_aiosqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ async def test_combined_metadata(self) -> None:
assert checkpoint.metadata == {
**self.metadata_2,
"thread_id": "thread-2",
"checkpoint_ns": "",
"run_id": "my_run_id",
}

Expand All @@ -94,14 +93,15 @@ async def test_asearch(self) -> None:
search_results_1 = [c async for c in saver.alist(None, filter=query_1)]
assert len(search_results_1) == 1
assert search_results_1[0].metadata == {
**self.config_1["configurable"],
"thread_id": "thread-1",
"thread_ts": "1",
**self.metadata_1,
}

search_results_2 = [c async for c in saver.alist(None, filter=query_2)]
assert len(search_results_2) == 1
assert search_results_2[0].metadata == {
**self.config_2["configurable"],
"thread_id": "thread-2",
**self.metadata_2,
}

Expand Down
6 changes: 3 additions & 3 deletions libs/checkpoint-sqlite/tests/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def test_combined_metadata(self) -> None:
assert checkpoint.metadata == {
**self.metadata_2,
"thread_id": "thread-2",
"checkpoint_ns": "",
"run_id": "my_run_id",
}

Expand All @@ -97,14 +96,15 @@ def test_search(self) -> None:
search_results_1 = list(saver.list(None, filter=query_1))
assert len(search_results_1) == 1
assert search_results_1[0].metadata == {
**self.config_1["configurable"],
"thread_id": "thread-1",
"thread_ts": "1",
**self.metadata_1,
}

search_results_2 = list(saver.list(None, filter=query_2))
assert len(search_results_2) == 1
assert search_results_2[0].metadata == {
**self.config_2["configurable"],
"thread_id": "thread-2",
**self.metadata_2,
}

Expand Down
Loading

0 comments on commit c2a129c

Please sign in to comment.