feat(graph): make add_edge function start_key parameter in order #3575

Open

kevinkelin wants to merge 2 commits into main
Conversation

kevinkelin

What problem is solved?

When a graph is built and compiled more than once (for example, once per request) and it contains an interrupt node, resuming the graph on the second run calls add_edge again. If start_key is a list whose order differs from the first build, the end_key node is never reachable on the second run.
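
A minimal illustration of the failure mode, assuming (as described in this report) that a multi-source edge is keyed by the tuple of start nodes in the order they were passed:

# Sketch: why list order matters. The names below are the ones from the
# reproduction further down.
first_run = ("node1", "node2", "node_interupt")
second_run = ("node1", "node_interupt", "node2")

assert first_run != second_run  # different keys: the resumed graph never
                                # matches the join recorded in the checkpoint
assert sorted(first_run) == sorted(second_run)  # sorting makes them identical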

In what scenarios will there be problems?

When using FastAPI or Flask as a backend service and the graph contains an interrupt node, resuming the graph after the interrupt fails. The FastAPI snippet below reproduces the issue.

import json

from fastapi import FastAPI
from pydantic import BaseModel
from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver
from starlette.responses import StreamingResponse
from typing_extensions import TypedDict
from typing import List, Annotated
from operator import add
from langgraph.graph import StateGraph, START, END
from motor.motor_asyncio import AsyncIOMotorClient

# `mongo_client` was referenced but never defined in the original snippet;
# a shared connection pool (connection URI assumed):
mongo_client = AsyncIOMotorClient("mongodb://localhost:27017")


# define the states
class GraphState(TypedDict):
    messages: Annotated[List[str], add]


async def node1(state: GraphState):
    return {"messages": ["node1 message"]}


async def node2(state: GraphState):
    return {"messages": ["node2 message"]}


async def node_interupt(state: GraphState):
    return {"messages": ["interupt node message"]}


async def flow_end(state: GraphState):
    return {"messages": ["flow end node message"]}


class Task(BaseModel):
    task_id: str
    entry_point: str = ""


app = FastAPI()


@app.post("/run")
async def run(task: Task):
    async def stream_all_response(response):
        # Re-emit graph events as server-sent events, skipping internal writes.
        async for event in response:
            kind = event.get("event")
            data = event.get("data")
            name = event.get("name")
            if name == "_write":
                continue
            if kind == "on_chain_end":
                ydata = {
                    "kind": kind,
                    "name": name,
                    "data": data
                }
                yield f'event: message\nretry: 15000\ndata: {json.dumps(ydata)}\n\n'

    # add node
    builder = StateGraph(GraphState)
    builder.add_node("node1", node1)
    builder.add_node("node2", node2)
    builder.add_node("node_interupt", node_interupt)
    builder.add_node("flow_end", flow_end)
    # add edges; the list-valued start_keys below trigger the ordering issue
    builder.add_edge(START, "node1")
    builder.add_edge("node1", "node2")
    builder.add_edge(["node1", "node2"], "node_interupt")
    builder.add_edge(["node1", "node2", "node_interupt"], "flow_end")
    builder.add_edge("flow_end", END)

    # checkpointer = MemorySaver()
    checkpointer = AsyncMongoDBSaver(
        client=mongo_client,  # shared Mongo connection pool (defined above)
        db_name="agent_test",
        checkpoint_collection_name="checkpoints_demon",
        writes_collection_name="checkpoint_writes_demon",
    )

    # renamed from `app` to avoid shadowing the FastAPI instance
    graph = builder.compile(checkpointer=checkpointer, interrupt_after=["node_interupt"])
    graph.get_graph().draw_mermaid_png(output_file_path="graph2.png")
    mermaid_data = graph.get_graph().draw_mermaid()
    print(mermaid_data)
    config = {
        "configurable": {
            "thread_id": task.task_id
        }
    }
    # resume from an interrupted node
    if task.entry_point:
        # Walk the checkpoint history (newest first) to find the snapshot
        # where the entry point node was still pending.
        pre_snapshot = None
        async for i in graph.aget_state_history(config=config):
            if task.entry_point in i.next:
                # Resume from the snapshot just after this one, or from this
                # one if it is the newest (pre_snapshot is still None then).
                config = pre_snapshot.config if pre_snapshot else i.config
                response = graph.astream_events(None, config=config, version="v2")
                g = stream_all_response(response)
                return StreamingResponse(g, media_type="text/event-stream")
            else:
                pre_snapshot = i
        else:
            # No snapshot has the entry point pending anymore.
            return {"msg": "The graph has already finished!"}

    inputs = {"messages": []}
    response = graph.astream_events(inputs, config=config, version="v2")
    g = stream_all_response(response)
    return StreamingResponse(g, media_type="text/event-stream")


if __name__ == '__main__':
    import uvicorn

    # assumes this file is saved as main.py
    uvicorn.run("main:app", host="0.0.0.0", port=5002, reload=True)

[graph2.png: Mermaid rendering of the compiled graph]

When the /run endpoint is called for the first time with

{
    "task_id": "12345"
}

execution is interrupted at node_interupt.

When the endpoint is called a second time with

{
    "task_id": "12345",
    "entry_point": "node_interupt"
}

the graph resumes and runs to completion.
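
For illustration, a minimal client for the two calls above (hypothetical: assumes the service above is running on localhost:5002 and uses the requests library):

import requests

BASE = "http://localhost:5002"

# First call: runs until the interrupt after node_interupt.
with requests.post(f"{BASE}/run", json={"task_id": "12345"}, stream=True) as r:
    for line in r.iter_lines():
        print(line.decode())

# Second call: resumes from the interrupted node.
payload = {"task_id": "12345", "entry_point": "node_interupt"}
with requests.post(f"{BASE}/run", json=payload, stream=True) as r:
    for line in r.iter_lines():
        print(line.decode())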

But if, on the second run, add_edge is called with the same start_key list in a different order, for example

# the first run used builder.add_edge(["node1", "node2", "node_interupt"], "flow_end")

builder.add_edge(["node1", "node_interupt", "node2"], "flow_end")

then the flow_end node becomes unreachable and the graph simply ends.

This typically happens when the nodes used to build the graph are not hard-coded but are loaded from a database: when requests land on different backend instances behind a load balancer such as LVS, the order of start_key can easily differ between runs.

Proposed fix

In the add_edge method, sort start_key so that it is identical every time the graph is built.
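
A minimal, self-contained sketch of the proposed change (StateGraphSketch is a hypothetical stand-in; the real StateGraph.add_edge also validates node names and stores edges in its own internal structures):

from typing import Sequence, Union


class StateGraphSketch:
    """Stand-in for StateGraph, only to show the proposed sorting."""

    def __init__(self) -> None:
        self.edges: set = set()          # single-source edges
        self.waiting_edges: set = set()  # multi-source (join) edges

    def add_edge(self, start_key: Union[str, Sequence[str]], end_key: str) -> None:
        if isinstance(start_key, str):
            self.edges.add((start_key, end_key))
            return
        # Proposed change: sort the upstream node names so the stored
        # multi-source edge is deterministic regardless of list order.
        self.waiting_edges.add((tuple(sorted(start_key)), end_key))


g1, g2 = StateGraphSketch(), StateGraphSketch()
g1.add_edge(["node1", "node2", "node_interupt"], "flow_end")
g2.add_edge(["node1", "node_interupt", "node2"], "flow_end")
assert g1.waiting_edges == g2.waiting_edges  # identical despite input order

With this change, the two differently ordered lists above produce the same edge, so a graph rebuilt from database-ordered node data resumes correctly.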

