New syntax: adding test cases and syntactic sugar tools in graph api for merges and forks (related to #323 and #324, allows #328).
This commit is contained in:
@ -37,11 +37,20 @@ class GraphCursor:
|
||||
)
|
||||
|
||||
if len(nodes):
|
||||
chain = self.graph.add_chain(*nodes, _input=self.last)
|
||||
chain = self.graph.add_chain(*nodes, _input=self.last, use_existing_nodes=True)
|
||||
return GraphCursor(chain.graph, first=self.first, last=chain.output)
|
||||
|
||||
return self
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
return None
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.graph == other.graph and self.first == other.first and self.last == other.last
|
||||
|
||||
|
||||
class PartialGraph:
|
||||
def __init__(self, *nodes):
|
||||
@ -73,6 +82,15 @@ class Graph:
|
||||
def __getitem__(self, key):
|
||||
return self.nodes[key]
|
||||
|
||||
def __enter__(self):
|
||||
return self.get_cursor().__enter__()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
return None
|
||||
|
||||
def __rshift__(self, other):
|
||||
return self.get_cursor().__rshift__(other)
|
||||
|
||||
def get_cursor(self, ref=BEGIN):
|
||||
return GraphCursor(self, last=self.index_of(ref))
|
||||
|
||||
@ -96,6 +114,9 @@ class Graph:
|
||||
|
||||
raise ValueError("Cannot find node matching {!r}.".format(mixed))
|
||||
|
||||
def indexes_of(self, *things):
|
||||
return set(map(self.index_of, things))
|
||||
|
||||
def outputs_of(self, idx_or_node, create=False):
|
||||
"""Get a set of the outputs for a given node, node index or name.
|
||||
"""
|
||||
@ -105,13 +126,13 @@ class Graph:
|
||||
self.edges[idx_or_node] = set()
|
||||
return self.edges[idx_or_node]
|
||||
|
||||
def add_node(self, c, *, _name=None):
|
||||
def add_node(self, new_node, *, _name=None):
|
||||
"""Add a node without connections in this graph and returns its index.
|
||||
If _name is specified, name this node (string reference for further usage).
|
||||
"""
|
||||
idx = len(self.nodes)
|
||||
self.edges[idx] = set()
|
||||
self.nodes.append(c)
|
||||
self.nodes.append(new_node)
|
||||
|
||||
if _name:
|
||||
if _name in self.named:
|
||||
@ -120,7 +141,14 @@ class Graph:
|
||||
|
||||
return idx
|
||||
|
||||
def add_chain(self, *nodes, _input=BEGIN, _output=None, _name=None):
|
||||
def get_or_add_node(self, new_node, *, _name=None):
|
||||
if new_node in self.nodes:
|
||||
if _name is not None:
|
||||
raise RuntimeError("Cannot name a node that is already present in the graph.")
|
||||
return self.index_of(new_node)
|
||||
return self.add_node(new_node, _name=_name)
|
||||
|
||||
def add_chain(self, *nodes, _input=BEGIN, _output=None, _name=None, use_existing_nodes=False):
|
||||
"""Add `nodes` as a chain in this graph.
|
||||
|
||||
**Input rules**
|
||||
@ -153,6 +181,8 @@ class Graph:
|
||||
_first = None
|
||||
_last = None
|
||||
|
||||
get_node = self.get_or_add_node if use_existing_nodes else self.add_node
|
||||
|
||||
# Sanity checks.
|
||||
if not len(nodes):
|
||||
if _input is None or _output is None:
|
||||
@ -164,7 +194,7 @@ class Graph:
|
||||
raise RuntimeError("Using add_chain(...) without nodes does not allow to use the _name parameter.")
|
||||
|
||||
for i, node in enumerate(nodes):
|
||||
_last = self.add_node(node, _name=_name if not i else None)
|
||||
_last = get_node(node, _name=_name if not i else None)
|
||||
|
||||
if _first is None:
|
||||
_first = _last
|
||||
|
||||
129
tests/structs/test_graphs_new_syntax.py
Normal file
129
tests/structs/test_graphs_new_syntax.py
Normal file
@ -0,0 +1,129 @@
|
||||
from operator import attrgetter
|
||||
from unittest.mock import sentinel
|
||||
|
||||
import pytest
|
||||
|
||||
from bonobo.constants import BEGIN
|
||||
from bonobo.structs.graphs import Graph, GraphCursor
|
||||
from bonobo.util import tuplize
|
||||
|
||||
|
||||
@tuplize
|
||||
def get_pseudo_nodes(*names):
|
||||
for name in names:
|
||||
yield getattr(sentinel, name)
|
||||
|
||||
|
||||
def test_get_cursor():
|
||||
g = Graph()
|
||||
cursor = g.get_cursor()
|
||||
|
||||
assert cursor.graph is g
|
||||
assert cursor.first is BEGIN
|
||||
assert cursor.last is BEGIN
|
||||
|
||||
|
||||
def test_get_cursor_in_a_vacuum():
|
||||
g = Graph()
|
||||
cursor = g.get_cursor(None)
|
||||
|
||||
assert cursor.graph is g
|
||||
assert cursor.first is None
|
||||
assert cursor.last is None
|
||||
|
||||
|
||||
def test_cursor_usage_to_add_a_chain():
|
||||
a, b, c = get_pseudo_nodes(*"abc")
|
||||
|
||||
g = Graph()
|
||||
|
||||
g.get_cursor() >> a >> b >> c
|
||||
|
||||
assert len(g) == 3
|
||||
assert g.outputs_of(BEGIN) == {g.index_of(a)}
|
||||
assert g.outputs_of(a) == {g.index_of(b)}
|
||||
assert g.outputs_of(b) == {g.index_of(c)}
|
||||
assert g.outputs_of(c) == set()
|
||||
|
||||
|
||||
def test_cursor_usage_to_add_a_chain_in_a_context_manager():
|
||||
a, b, c = get_pseudo_nodes(*"abc")
|
||||
|
||||
g = Graph()
|
||||
with g as cur:
|
||||
cur >> a >> b >> c
|
||||
|
||||
assert len(g) == 3
|
||||
assert g.outputs_of(BEGIN) == {g.index_of(a)}
|
||||
assert g.outputs_of(a) == {g.index_of(b)}
|
||||
assert g.outputs_of(b) == {g.index_of(c)}
|
||||
assert g.outputs_of(c) == set()
|
||||
|
||||
|
||||
def test_implicit_cursor_usage():
|
||||
a, b, c = get_pseudo_nodes(*"abc")
|
||||
|
||||
g = Graph()
|
||||
g >> a >> b >> c
|
||||
|
||||
assert len(g) == 3
|
||||
assert g.outputs_of(BEGIN) == {g.index_of(a)}
|
||||
assert g.outputs_of(a) == {g.index_of(b)}
|
||||
assert g.outputs_of(b) == {g.index_of(c)}
|
||||
assert g.outputs_of(c) == set()
|
||||
|
||||
|
||||
def test_cursor_to_fork_a_graph():
|
||||
a, b, c, d, e = get_pseudo_nodes(*"abcde")
|
||||
|
||||
g = Graph()
|
||||
g >> a >> b >> c
|
||||
g.get_cursor(b) >> d >> e
|
||||
|
||||
assert len(g) == 5
|
||||
assert g.outputs_of(BEGIN) == {g.index_of(a)}
|
||||
assert g.outputs_of(a) == {g.index_of(b)}
|
||||
assert g.outputs_of(b) == {g.index_of(c), g.index_of(d)}
|
||||
assert g.outputs_of(c) == set()
|
||||
assert g.outputs_of(d) == {g.index_of(e)}
|
||||
assert g.outputs_of(e) == set()
|
||||
|
||||
|
||||
def test_cursor_to_fork_at_the_end():
|
||||
a, b, c, d, e = get_pseudo_nodes(*"abcde")
|
||||
|
||||
g = Graph()
|
||||
c0 = g >> a >> b
|
||||
c1 = c0 >> c
|
||||
c2 = c0 >> d >> e
|
||||
|
||||
assert len(g) == 5
|
||||
assert g.outputs_of(BEGIN) == {g.index_of(a)}
|
||||
assert g.outputs_of(a) == {g.index_of(b)}
|
||||
assert g.outputs_of(b) == {g.index_of(c), g.index_of(d)}
|
||||
assert g.outputs_of(c) == set()
|
||||
assert g.outputs_of(d) == {g.index_of(e)}
|
||||
assert g.outputs_of(e) == set()
|
||||
|
||||
assert c0.first == g.index_of(BEGIN)
|
||||
assert c0.last == g.index_of(b)
|
||||
assert c1.first == g.index_of(BEGIN)
|
||||
assert c1.last == g.index_of(c)
|
||||
assert c2.first == g.index_of(BEGIN)
|
||||
assert c2.last == g.index_of(e)
|
||||
|
||||
|
||||
def test_cursor_merge():
|
||||
a, b, c = get_pseudo_nodes(*"abc")
|
||||
g = Graph()
|
||||
|
||||
c1 = g >> a >> c
|
||||
c2 = g >> b >> c
|
||||
|
||||
assert len(g) == 3
|
||||
assert g.outputs_of(BEGIN) == g.indexes_of(a, b)
|
||||
assert g.outputs_of(a) == g.indexes_of(c)
|
||||
assert g.outputs_of(b) == g.indexes_of(c)
|
||||
assert g.outputs_of(c) == set()
|
||||
|
||||
assert c1 == c2
|
||||
Reference in New Issue
Block a user