New syntax: adding test cases and syntactic sugar tools in graph api for merges and forks (related to #323 and #324, allows #328).
This commit is contained in:
@ -37,11 +37,20 @@ class GraphCursor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if len(nodes):
|
if len(nodes):
|
||||||
chain = self.graph.add_chain(*nodes, _input=self.last)
|
chain = self.graph.add_chain(*nodes, _input=self.last, use_existing_nodes=True)
|
||||||
return GraphCursor(chain.graph, first=self.first, last=chain.output)
|
return GraphCursor(chain.graph, first=self.first, last=chain.output)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.graph == other.graph and self.first == other.first and self.last == other.last
|
||||||
|
|
||||||
|
|
||||||
class PartialGraph:
|
class PartialGraph:
|
||||||
def __init__(self, *nodes):
|
def __init__(self, *nodes):
|
||||||
@ -73,6 +82,15 @@ class Graph:
|
|||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
return self.nodes[key]
|
return self.nodes[key]
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self.get_cursor().__enter__()
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __rshift__(self, other):
|
||||||
|
return self.get_cursor().__rshift__(other)
|
||||||
|
|
||||||
def get_cursor(self, ref=BEGIN):
|
def get_cursor(self, ref=BEGIN):
|
||||||
return GraphCursor(self, last=self.index_of(ref))
|
return GraphCursor(self, last=self.index_of(ref))
|
||||||
|
|
||||||
@ -96,6 +114,9 @@ class Graph:
|
|||||||
|
|
||||||
raise ValueError("Cannot find node matching {!r}.".format(mixed))
|
raise ValueError("Cannot find node matching {!r}.".format(mixed))
|
||||||
|
|
||||||
|
def indexes_of(self, *things):
|
||||||
|
return set(map(self.index_of, things))
|
||||||
|
|
||||||
def outputs_of(self, idx_or_node, create=False):
|
def outputs_of(self, idx_or_node, create=False):
|
||||||
"""Get a set of the outputs for a given node, node index or name.
|
"""Get a set of the outputs for a given node, node index or name.
|
||||||
"""
|
"""
|
||||||
@ -105,13 +126,13 @@ class Graph:
|
|||||||
self.edges[idx_or_node] = set()
|
self.edges[idx_or_node] = set()
|
||||||
return self.edges[idx_or_node]
|
return self.edges[idx_or_node]
|
||||||
|
|
||||||
def add_node(self, c, *, _name=None):
|
def add_node(self, new_node, *, _name=None):
|
||||||
"""Add a node without connections in this graph and returns its index.
|
"""Add a node without connections in this graph and returns its index.
|
||||||
If _name is specified, name this node (string reference for further usage).
|
If _name is specified, name this node (string reference for further usage).
|
||||||
"""
|
"""
|
||||||
idx = len(self.nodes)
|
idx = len(self.nodes)
|
||||||
self.edges[idx] = set()
|
self.edges[idx] = set()
|
||||||
self.nodes.append(c)
|
self.nodes.append(new_node)
|
||||||
|
|
||||||
if _name:
|
if _name:
|
||||||
if _name in self.named:
|
if _name in self.named:
|
||||||
@ -120,7 +141,14 @@ class Graph:
|
|||||||
|
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
def add_chain(self, *nodes, _input=BEGIN, _output=None, _name=None):
|
def get_or_add_node(self, new_node, *, _name=None):
|
||||||
|
if new_node in self.nodes:
|
||||||
|
if _name is not None:
|
||||||
|
raise RuntimeError("Cannot name a node that is already present in the graph.")
|
||||||
|
return self.index_of(new_node)
|
||||||
|
return self.add_node(new_node, _name=_name)
|
||||||
|
|
||||||
|
def add_chain(self, *nodes, _input=BEGIN, _output=None, _name=None, use_existing_nodes=False):
|
||||||
"""Add `nodes` as a chain in this graph.
|
"""Add `nodes` as a chain in this graph.
|
||||||
|
|
||||||
**Input rules**
|
**Input rules**
|
||||||
@ -153,6 +181,8 @@ class Graph:
|
|||||||
_first = None
|
_first = None
|
||||||
_last = None
|
_last = None
|
||||||
|
|
||||||
|
get_node = self.get_or_add_node if use_existing_nodes else self.add_node
|
||||||
|
|
||||||
# Sanity checks.
|
# Sanity checks.
|
||||||
if not len(nodes):
|
if not len(nodes):
|
||||||
if _input is None or _output is None:
|
if _input is None or _output is None:
|
||||||
@ -164,7 +194,7 @@ class Graph:
|
|||||||
raise RuntimeError("Using add_chain(...) without nodes does not allow to use the _name parameter.")
|
raise RuntimeError("Using add_chain(...) without nodes does not allow to use the _name parameter.")
|
||||||
|
|
||||||
for i, node in enumerate(nodes):
|
for i, node in enumerate(nodes):
|
||||||
_last = self.add_node(node, _name=_name if not i else None)
|
_last = get_node(node, _name=_name if not i else None)
|
||||||
|
|
||||||
if _first is None:
|
if _first is None:
|
||||||
_first = _last
|
_first = _last
|
||||||
|
|||||||
129
tests/structs/test_graphs_new_syntax.py
Normal file
129
tests/structs/test_graphs_new_syntax.py
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
from operator import attrgetter
|
||||||
|
from unittest.mock import sentinel
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from bonobo.constants import BEGIN
|
||||||
|
from bonobo.structs.graphs import Graph, GraphCursor
|
||||||
|
from bonobo.util import tuplize
|
||||||
|
|
||||||
|
|
||||||
|
@tuplize
|
||||||
|
def get_pseudo_nodes(*names):
|
||||||
|
for name in names:
|
||||||
|
yield getattr(sentinel, name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cursor():
|
||||||
|
g = Graph()
|
||||||
|
cursor = g.get_cursor()
|
||||||
|
|
||||||
|
assert cursor.graph is g
|
||||||
|
assert cursor.first is BEGIN
|
||||||
|
assert cursor.last is BEGIN
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cursor_in_a_vacuum():
|
||||||
|
g = Graph()
|
||||||
|
cursor = g.get_cursor(None)
|
||||||
|
|
||||||
|
assert cursor.graph is g
|
||||||
|
assert cursor.first is None
|
||||||
|
assert cursor.last is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_cursor_usage_to_add_a_chain():
|
||||||
|
a, b, c = get_pseudo_nodes(*"abc")
|
||||||
|
|
||||||
|
g = Graph()
|
||||||
|
|
||||||
|
g.get_cursor() >> a >> b >> c
|
||||||
|
|
||||||
|
assert len(g) == 3
|
||||||
|
assert g.outputs_of(BEGIN) == {g.index_of(a)}
|
||||||
|
assert g.outputs_of(a) == {g.index_of(b)}
|
||||||
|
assert g.outputs_of(b) == {g.index_of(c)}
|
||||||
|
assert g.outputs_of(c) == set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_cursor_usage_to_add_a_chain_in_a_context_manager():
|
||||||
|
a, b, c = get_pseudo_nodes(*"abc")
|
||||||
|
|
||||||
|
g = Graph()
|
||||||
|
with g as cur:
|
||||||
|
cur >> a >> b >> c
|
||||||
|
|
||||||
|
assert len(g) == 3
|
||||||
|
assert g.outputs_of(BEGIN) == {g.index_of(a)}
|
||||||
|
assert g.outputs_of(a) == {g.index_of(b)}
|
||||||
|
assert g.outputs_of(b) == {g.index_of(c)}
|
||||||
|
assert g.outputs_of(c) == set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_implicit_cursor_usage():
|
||||||
|
a, b, c = get_pseudo_nodes(*"abc")
|
||||||
|
|
||||||
|
g = Graph()
|
||||||
|
g >> a >> b >> c
|
||||||
|
|
||||||
|
assert len(g) == 3
|
||||||
|
assert g.outputs_of(BEGIN) == {g.index_of(a)}
|
||||||
|
assert g.outputs_of(a) == {g.index_of(b)}
|
||||||
|
assert g.outputs_of(b) == {g.index_of(c)}
|
||||||
|
assert g.outputs_of(c) == set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_cursor_to_fork_a_graph():
|
||||||
|
a, b, c, d, e = get_pseudo_nodes(*"abcde")
|
||||||
|
|
||||||
|
g = Graph()
|
||||||
|
g >> a >> b >> c
|
||||||
|
g.get_cursor(b) >> d >> e
|
||||||
|
|
||||||
|
assert len(g) == 5
|
||||||
|
assert g.outputs_of(BEGIN) == {g.index_of(a)}
|
||||||
|
assert g.outputs_of(a) == {g.index_of(b)}
|
||||||
|
assert g.outputs_of(b) == {g.index_of(c), g.index_of(d)}
|
||||||
|
assert g.outputs_of(c) == set()
|
||||||
|
assert g.outputs_of(d) == {g.index_of(e)}
|
||||||
|
assert g.outputs_of(e) == set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_cursor_to_fork_at_the_end():
|
||||||
|
a, b, c, d, e = get_pseudo_nodes(*"abcde")
|
||||||
|
|
||||||
|
g = Graph()
|
||||||
|
c0 = g >> a >> b
|
||||||
|
c1 = c0 >> c
|
||||||
|
c2 = c0 >> d >> e
|
||||||
|
|
||||||
|
assert len(g) == 5
|
||||||
|
assert g.outputs_of(BEGIN) == {g.index_of(a)}
|
||||||
|
assert g.outputs_of(a) == {g.index_of(b)}
|
||||||
|
assert g.outputs_of(b) == {g.index_of(c), g.index_of(d)}
|
||||||
|
assert g.outputs_of(c) == set()
|
||||||
|
assert g.outputs_of(d) == {g.index_of(e)}
|
||||||
|
assert g.outputs_of(e) == set()
|
||||||
|
|
||||||
|
assert c0.first == g.index_of(BEGIN)
|
||||||
|
assert c0.last == g.index_of(b)
|
||||||
|
assert c1.first == g.index_of(BEGIN)
|
||||||
|
assert c1.last == g.index_of(c)
|
||||||
|
assert c2.first == g.index_of(BEGIN)
|
||||||
|
assert c2.last == g.index_of(e)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cursor_merge():
|
||||||
|
a, b, c = get_pseudo_nodes(*"abc")
|
||||||
|
g = Graph()
|
||||||
|
|
||||||
|
c1 = g >> a >> c
|
||||||
|
c2 = g >> b >> c
|
||||||
|
|
||||||
|
assert len(g) == 3
|
||||||
|
assert g.outputs_of(BEGIN) == g.indexes_of(a, b)
|
||||||
|
assert g.outputs_of(a) == g.indexes_of(c)
|
||||||
|
assert g.outputs_of(b) == g.indexes_of(c)
|
||||||
|
assert g.outputs_of(c) == set()
|
||||||
|
|
||||||
|
assert c1 == c2
|
||||||
Reference in New Issue
Block a user