mirror of https://github.com/hwchase17/langchain
Add Runnable.get_graph() to get a graph representation of a Runnable (#15040)
It can be drawn in ascii with Runnable.get_graph().draw()pull/15080/head
parent
aad3d8bd47
commit
7d5800ee51
@ -0,0 +1,133 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, NamedTuple, Optional, Type, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import Runnable
|
||||
from langchain_core.runnables.graph_draw import draw
|
||||
|
||||
|
||||
class Edge(NamedTuple):
|
||||
source: str
|
||||
target: str
|
||||
|
||||
|
||||
class Node(NamedTuple):
|
||||
id: str
|
||||
data: Union[Type[BaseModel], Runnable]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Graph:
|
||||
nodes: Dict[str, Node] = field(default_factory=dict)
|
||||
edges: List[Edge] = field(default_factory=list)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.nodes)
|
||||
|
||||
def next_id(self) -> str:
|
||||
return uuid4().hex
|
||||
|
||||
def add_node(self, data: Union[Type[BaseModel], Runnable]) -> Node:
|
||||
"""Add a node to the graph and return it."""
|
||||
node = Node(id=self.next_id(), data=data)
|
||||
self.nodes[node.id] = node
|
||||
return node
|
||||
|
||||
def remove_node(self, node: Node) -> None:
|
||||
"""Remove a node from the graphm and all edges connected to it."""
|
||||
self.nodes.pop(node.id)
|
||||
self.edges = [
|
||||
edge
|
||||
for edge in self.edges
|
||||
if edge.source != node.id and edge.target != node.id
|
||||
]
|
||||
|
||||
def add_edge(self, source: Node, target: Node) -> Edge:
|
||||
"""Add an edge to the graph and return it."""
|
||||
if source.id not in self.nodes:
|
||||
raise ValueError(f"Source node {source.id} not in graph")
|
||||
if target.id not in self.nodes:
|
||||
raise ValueError(f"Target node {target.id} not in graph")
|
||||
edge = Edge(source=source.id, target=target.id)
|
||||
self.edges.append(edge)
|
||||
return edge
|
||||
|
||||
def extend(self, graph: Graph) -> None:
|
||||
"""Add all nodes and edges from another graph.
|
||||
Note this doesn't check for duplicates, nor does it connect the graphs."""
|
||||
self.nodes.update(graph.nodes)
|
||||
self.edges.extend(graph.edges)
|
||||
|
||||
def first_node(self) -> Optional[Node]:
|
||||
"""Find the single node that is not a target of any edge.
|
||||
If there is no such node, or there are multiple, return None.
|
||||
When drawing the graph this node would be the origin."""
|
||||
targets = {edge.target for edge in self.edges}
|
||||
found: List[Node] = []
|
||||
for node in self.nodes.values():
|
||||
if node.id not in targets:
|
||||
found.append(node)
|
||||
return found[0] if len(found) == 1 else None
|
||||
|
||||
def last_node(self) -> Optional[Node]:
|
||||
"""Find the single node that is not a source of any edge.
|
||||
If there is no such node, or there are multiple, return None.
|
||||
When drawing the graph this node would be the destination.
|
||||
"""
|
||||
sources = {edge.source for edge in self.edges}
|
||||
found: List[Node] = []
|
||||
for node in self.nodes.values():
|
||||
if node.id not in sources:
|
||||
found.append(node)
|
||||
return found[0] if len(found) == 1 else None
|
||||
|
||||
def trim_first_node(self) -> None:
|
||||
"""Remove the first node if it exists and has a single outgoing edge,
|
||||
ie. if removing it would not leave the graph without a "first" node."""
|
||||
first_node = self.first_node()
|
||||
if first_node:
|
||||
if (
|
||||
len(self.nodes) == 1
|
||||
or len([edge for edge in self.edges if edge.source == first_node.id])
|
||||
== 1
|
||||
):
|
||||
self.remove_node(first_node)
|
||||
|
||||
def trim_last_node(self) -> None:
|
||||
"""Remove the last node if it exists and has a single incoming edge,
|
||||
ie. if removing it would not leave the graph without a "last" node."""
|
||||
last_node = self.last_node()
|
||||
if last_node:
|
||||
if (
|
||||
len(self.nodes) == 1
|
||||
or len([edge for edge in self.edges if edge.target == last_node.id])
|
||||
== 1
|
||||
):
|
||||
self.remove_node(last_node)
|
||||
|
||||
def draw_ascii(self) -> str:
|
||||
def node_data(node: Node) -> str:
|
||||
if isinstance(node.data, Runnable):
|
||||
try:
|
||||
data = str(node.data)
|
||||
if (
|
||||
data.startswith("<")
|
||||
or data[0] != data[0].upper()
|
||||
or len(data.splitlines()) > 1
|
||||
):
|
||||
data = node.data.__class__.__name__
|
||||
elif len(data) > 36:
|
||||
data = data[:36] + "..."
|
||||
except Exception:
|
||||
data = node.data.__class__.__name__
|
||||
else:
|
||||
data = node.data.__name__
|
||||
return data
|
||||
|
||||
return draw(
|
||||
{node.id: node_data(node) for node in self.nodes.values()},
|
||||
[(edge.source, edge.target) for edge in self.edges],
|
||||
)
|
@ -0,0 +1,304 @@
|
||||
"""Draws DAG in ASCII.
|
||||
Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py"""
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Mapping, Sequence, Tuple
|
||||
|
||||
|
||||
class VertexViewer:
|
||||
"""Class to define vertex box boundaries that will be accounted for during
|
||||
graph building by grandalf.
|
||||
|
||||
Args:
|
||||
name (str): name of the vertex.
|
||||
"""
|
||||
|
||||
HEIGHT = 3 # top and bottom box edges + text
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self._h = self.HEIGHT # top and bottom box edges + text
|
||||
self._w = len(name) + 2 # right and left bottom edges + text
|
||||
|
||||
@property
|
||||
def h(self) -> int:
|
||||
"""Height of the box."""
|
||||
return self._h
|
||||
|
||||
@property
|
||||
def w(self) -> int:
|
||||
"""Width of the box."""
|
||||
return self._w
|
||||
|
||||
|
||||
class AsciiCanvas:
|
||||
"""Class for drawing in ASCII.
|
||||
|
||||
Args:
|
||||
cols (int): number of columns in the canvas. Should be > 1.
|
||||
lines (int): number of lines in the canvas. Should be > 1.
|
||||
"""
|
||||
|
||||
TIMEOUT = 10
|
||||
|
||||
def __init__(self, cols: int, lines: int) -> None:
|
||||
assert cols > 1
|
||||
assert lines > 1
|
||||
|
||||
self.cols = cols
|
||||
self.lines = lines
|
||||
|
||||
self.canvas = [[" "] * cols for line in range(lines)]
|
||||
|
||||
def draw(self) -> str:
|
||||
"""Draws ASCII canvas on the screen."""
|
||||
lines = map("".join, self.canvas)
|
||||
return os.linesep.join(lines)
|
||||
|
||||
def point(self, x: int, y: int, char: str) -> None:
|
||||
"""Create a point on ASCII canvas.
|
||||
|
||||
Args:
|
||||
x (int): x coordinate. Should be >= 0 and < number of columns in
|
||||
the canvas.
|
||||
y (int): y coordinate. Should be >= 0 an < number of lines in the
|
||||
canvas.
|
||||
char (str): character to place in the specified point on the
|
||||
canvas.
|
||||
"""
|
||||
assert len(char) == 1
|
||||
assert x >= 0
|
||||
assert x < self.cols
|
||||
assert y >= 0
|
||||
assert y < self.lines
|
||||
|
||||
self.canvas[y][x] = char
|
||||
|
||||
def line(self, x0: int, y0: int, x1: int, y1: int, char: str) -> None:
|
||||
"""Create a line on ASCII canvas.
|
||||
|
||||
Args:
|
||||
x0 (int): x coordinate where the line should start.
|
||||
y0 (int): y coordinate where the line should start.
|
||||
x1 (int): x coordinate where the line should end.
|
||||
y1 (int): y coordinate where the line should end.
|
||||
char (str): character to draw the line with.
|
||||
"""
|
||||
if x0 > x1:
|
||||
x1, x0 = x0, x1
|
||||
y1, y0 = y0, y1
|
||||
|
||||
dx = x1 - x0
|
||||
dy = y1 - y0
|
||||
|
||||
if dx == 0 and dy == 0:
|
||||
self.point(x0, y0, char)
|
||||
elif abs(dx) >= abs(dy):
|
||||
for x in range(x0, x1 + 1):
|
||||
if dx == 0:
|
||||
y = y0
|
||||
else:
|
||||
y = y0 + int(round((x - x0) * dy / float(dx)))
|
||||
self.point(x, y, char)
|
||||
elif y0 < y1:
|
||||
for y in range(y0, y1 + 1):
|
||||
if dy == 0:
|
||||
x = x0
|
||||
else:
|
||||
x = x0 + int(round((y - y0) * dx / float(dy)))
|
||||
self.point(x, y, char)
|
||||
else:
|
||||
for y in range(y1, y0 + 1):
|
||||
if dy == 0:
|
||||
x = x0
|
||||
else:
|
||||
x = x1 + int(round((y - y1) * dx / float(dy)))
|
||||
self.point(x, y, char)
|
||||
|
||||
def text(self, x: int, y: int, text: str) -> None:
|
||||
"""Print a text on ASCII canvas.
|
||||
|
||||
Args:
|
||||
x (int): x coordinate where the text should start.
|
||||
y (int): y coordinate where the text should start.
|
||||
text (str): string that should be printed.
|
||||
"""
|
||||
for i, char in enumerate(text):
|
||||
self.point(x + i, y, char)
|
||||
|
||||
def box(self, x0: int, y0: int, width: int, height: int) -> None:
|
||||
"""Create a box on ASCII canvas.
|
||||
|
||||
Args:
|
||||
x0 (int): x coordinate of the box corner.
|
||||
y0 (int): y coordinate of the box corner.
|
||||
width (int): box width.
|
||||
height (int): box height.
|
||||
"""
|
||||
assert width > 1
|
||||
assert height > 1
|
||||
|
||||
width -= 1
|
||||
height -= 1
|
||||
|
||||
for x in range(x0, x0 + width):
|
||||
self.point(x, y0, "-")
|
||||
self.point(x, y0 + height, "-")
|
||||
|
||||
for y in range(y0, y0 + height):
|
||||
self.point(x0, y, "|")
|
||||
self.point(x0 + width, y, "|")
|
||||
|
||||
self.point(x0, y0, "+")
|
||||
self.point(x0 + width, y0, "+")
|
||||
self.point(x0, y0 + height, "+")
|
||||
self.point(x0 + width, y0 + height, "+")
|
||||
|
||||
|
||||
def _build_sugiyama_layout(
|
||||
vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]
|
||||
) -> Any:
|
||||
try:
|
||||
from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import]
|
||||
from grandalf.layouts import SugiyamaLayout # type: ignore[import]
|
||||
from grandalf.routing import ( # type: ignore[import]
|
||||
EdgeViewer,
|
||||
route_with_lines,
|
||||
)
|
||||
except ImportError:
|
||||
print("Install grandalf to draw graphs. `pip install grandalf`")
|
||||
raise
|
||||
#
|
||||
# Just a reminder about naming conventions:
|
||||
# +------------X
|
||||
# |
|
||||
# |
|
||||
# |
|
||||
# |
|
||||
# Y
|
||||
#
|
||||
|
||||
vertices_ = {id: Vertex(f" {data} ") for id, data in vertices.items()}
|
||||
edges_ = [Edge(vertices_[s], vertices_[e]) for s, e in edges]
|
||||
vertices_list = vertices_.values()
|
||||
graph = Graph(vertices_list, edges_)
|
||||
|
||||
for vertex in vertices_list:
|
||||
vertex.view = VertexViewer(vertex.data)
|
||||
|
||||
# NOTE: determine min box length to create the best layout
|
||||
minw = min(v.view.w for v in vertices_list)
|
||||
|
||||
for edge in edges_:
|
||||
edge.view = EdgeViewer()
|
||||
|
||||
sug = SugiyamaLayout(graph.C[0])
|
||||
graph = graph.C[0]
|
||||
roots = list(filter(lambda x: len(x.e_in()) == 0, graph.sV))
|
||||
|
||||
sug.init_all(roots=roots, optimize=True)
|
||||
|
||||
sug.yspace = VertexViewer.HEIGHT
|
||||
sug.xspace = minw
|
||||
sug.route_edge = route_with_lines
|
||||
|
||||
sug.draw()
|
||||
|
||||
return sug
|
||||
|
||||
|
||||
def draw(vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]) -> str:
|
||||
"""Build a DAG and draw it in ASCII.
|
||||
|
||||
Args:
|
||||
vertices (list): list of graph vertices.
|
||||
edges (list): list of graph edges.
|
||||
|
||||
Returns:
|
||||
str: ASCII representation
|
||||
|
||||
Example:
|
||||
>>> from dvc.dagascii import draw
|
||||
>>> vertices = [1, 2, 3, 4]
|
||||
>>> edges = [(1, 2), (2, 3), (2, 4), (1, 4)]
|
||||
>>> print(draw(vertices, edges))
|
||||
+---+ +---+
|
||||
| 3 | | 4 |
|
||||
+---+ *+---+
|
||||
* ** *
|
||||
* ** *
|
||||
* * *
|
||||
+---+ *
|
||||
| 2 | *
|
||||
+---+ *
|
||||
* *
|
||||
* *
|
||||
**
|
||||
+---+
|
||||
| 1 |
|
||||
+---+
|
||||
"""
|
||||
|
||||
# NOTE: coordinates might me negative, so we need to shift
|
||||
# everything to the positive plane before we actually draw it.
|
||||
Xs = [] # noqa: N806
|
||||
Ys = [] # noqa: N806
|
||||
|
||||
sug = _build_sugiyama_layout(vertices, edges)
|
||||
|
||||
for vertex in sug.g.sV:
|
||||
# NOTE: moving boxes w/2 to the left
|
||||
Xs.append(vertex.view.xy[0] - vertex.view.w / 2.0)
|
||||
Xs.append(vertex.view.xy[0] + vertex.view.w / 2.0)
|
||||
Ys.append(vertex.view.xy[1])
|
||||
Ys.append(vertex.view.xy[1] + vertex.view.h)
|
||||
|
||||
for edge in sug.g.sE:
|
||||
for x, y in edge.view._pts:
|
||||
Xs.append(x)
|
||||
Ys.append(y)
|
||||
|
||||
minx = min(Xs)
|
||||
miny = min(Ys)
|
||||
maxx = max(Xs)
|
||||
maxy = max(Ys)
|
||||
|
||||
canvas_cols = int(math.ceil(math.ceil(maxx) - math.floor(minx))) + 1
|
||||
canvas_lines = int(round(maxy - miny))
|
||||
|
||||
canvas = AsciiCanvas(canvas_cols, canvas_lines)
|
||||
|
||||
# NOTE: first draw edges so that node boxes could overwrite them
|
||||
for edge in sug.g.sE:
|
||||
assert len(edge.view._pts) > 1
|
||||
for index in range(1, len(edge.view._pts)):
|
||||
start = edge.view._pts[index - 1]
|
||||
end = edge.view._pts[index]
|
||||
|
||||
start_x = int(round(start[0] - minx))
|
||||
start_y = int(round(start[1] - miny))
|
||||
end_x = int(round(end[0] - minx))
|
||||
end_y = int(round(end[1] - miny))
|
||||
|
||||
assert start_x >= 0
|
||||
assert start_y >= 0
|
||||
assert end_x >= 0
|
||||
assert end_y >= 0
|
||||
|
||||
canvas.line(start_x, start_y, end_x, end_y, "*")
|
||||
|
||||
for vertex in sug.g.sV:
|
||||
# NOTE: moving boxes w/2 to the left
|
||||
x = vertex.view.xy[0] - vertex.view.w / 2.0
|
||||
y = vertex.view.xy[1]
|
||||
|
||||
canvas.box(
|
||||
int(round(x - minx)),
|
||||
int(round(y - miny)),
|
||||
vertex.view.w,
|
||||
vertex.view.h,
|
||||
)
|
||||
|
||||
canvas.text(int(round(x - minx)) + 1, int(round(y - miny)) + 1, vertex.data)
|
||||
|
||||
return canvas.draw()
|
@ -0,0 +1,88 @@
|
||||
# serializer version: 1
|
||||
# name: test_graph_sequence
|
||||
'''
|
||||
+-------------+
|
||||
| PromptInput |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+----------------+
|
||||
| PromptTemplate |
|
||||
+----------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-------------+
|
||||
| FakeListLLM |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+--------------------------------+
|
||||
| CommaSeparatedListOutputParser |
|
||||
+--------------------------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+--------------------------------------+
|
||||
| CommaSeparatedListOutputParserOutput |
|
||||
+--------------------------------------+
|
||||
'''
|
||||
# ---
|
||||
# name: test_graph_sequence_map
|
||||
'''
|
||||
+-------------+
|
||||
| PromptInput |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+----------------+
|
||||
| PromptTemplate |
|
||||
+----------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-------------+
|
||||
| FakeListLLM |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-----------------------+
|
||||
| RunnableParallelInput |
|
||||
+-----------------------+
|
||||
**** ***
|
||||
**** ****
|
||||
** **
|
||||
+---------------------+ +--------------------------------+
|
||||
| RunnablePassthrough | | CommaSeparatedListOutputParser |
|
||||
+---------------------+ +--------------------------------+
|
||||
**** ***
|
||||
**** ****
|
||||
** **
|
||||
+------------------------+
|
||||
| RunnableParallelOutput |
|
||||
+------------------------+
|
||||
'''
|
||||
# ---
|
||||
# name: test_graph_single_runnable
|
||||
'''
|
||||
+----------------------+
|
||||
| StrOutputParserInput |
|
||||
+----------------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-----------------+
|
||||
| StrOutputParser |
|
||||
+-----------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-----------------------+
|
||||
| StrOutputParserOutput |
|
||||
+-----------------------+
|
||||
'''
|
||||
# ---
|
File diff suppressed because one or more lines are too long
@ -0,0 +1,51 @@
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.runnables.base import Runnable
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from tests.unit_tests.fake.llm import FakeListLLM
|
||||
|
||||
|
||||
def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
|
||||
runnable = StrOutputParser()
|
||||
graph = StrOutputParser().get_graph()
|
||||
first_node = graph.first_node()
|
||||
assert first_node is not None
|
||||
assert first_node.data.schema() == runnable.input_schema.schema() # type: ignore[union-attr]
|
||||
last_node = graph.last_node()
|
||||
assert last_node is not None
|
||||
assert last_node.data.schema() == runnable.output_schema.schema() # type: ignore[union-attr]
|
||||
assert len(graph.nodes) == 3
|
||||
assert len(graph.edges) == 2
|
||||
assert graph.edges[0].source == first_node.id
|
||||
assert graph.edges[1].target == last_node.id
|
||||
assert graph.draw_ascii() == snapshot
|
||||
|
||||
|
||||
def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
|
||||
fake_llm = FakeListLLM(responses=["a"])
|
||||
prompt = PromptTemplate.from_template("Hello, {name}!")
|
||||
list_parser = CommaSeparatedListOutputParser()
|
||||
|
||||
sequence = prompt | fake_llm | list_parser
|
||||
graph = sequence.get_graph()
|
||||
assert graph.draw_ascii() == snapshot
|
||||
|
||||
|
||||
def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
|
||||
fake_llm = FakeListLLM(responses=["a"])
|
||||
prompt = PromptTemplate.from_template("Hello, {name}!")
|
||||
list_parser = CommaSeparatedListOutputParser()
|
||||
|
||||
sequence: Runnable = (
|
||||
prompt
|
||||
| fake_llm
|
||||
| {
|
||||
"original": RunnablePassthrough(input_type=str),
|
||||
"as_list": list_parser,
|
||||
}
|
||||
)
|
||||
graph = sequence.get_graph()
|
||||
assert graph.draw_ascii() == snapshot
|
Loading…
Reference in New Issue