New syntax: adding test cases and syntactic sugar tools in graph api for merges and forks (related to #323 and #324, allows #328).

This commit is contained in:
Romain Dorgueil
2019-06-01 12:31:38 +02:00
parent 0d17881928
commit c998708923
2 changed files with 164 additions and 5 deletions

View File

@ -37,11 +37,20 @@ class GraphCursor:
)
if len(nodes):
chain = self.graph.add_chain(*nodes, _input=self.last)
chain = self.graph.add_chain(*nodes, _input=self.last, use_existing_nodes=True)
return GraphCursor(chain.graph, first=self.first, last=chain.output)
return self
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
return None
def __eq__(self, other):
return self.graph == other.graph and self.first == other.first and self.last == other.last
class PartialGraph:
def __init__(self, *nodes):
@ -73,6 +82,15 @@ class Graph:
def __getitem__(self, key):
return self.nodes[key]
def __enter__(self):
return self.get_cursor().__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
return None
def __rshift__(self, other):
return self.get_cursor().__rshift__(other)
def get_cursor(self, ref=BEGIN):
return GraphCursor(self, last=self.index_of(ref))
@ -96,6 +114,9 @@ class Graph:
raise ValueError("Cannot find node matching {!r}.".format(mixed))
def indexes_of(self, *things):
return set(map(self.index_of, things))
def outputs_of(self, idx_or_node, create=False):
"""Get a set of the outputs for a given node, node index or name.
"""
@ -105,13 +126,13 @@ class Graph:
self.edges[idx_or_node] = set()
return self.edges[idx_or_node]
def add_node(self, c, *, _name=None):
def add_node(self, new_node, *, _name=None):
"""Add a node without connections in this graph and returns its index.
If _name is specified, name this node (string reference for further usage).
"""
idx = len(self.nodes)
self.edges[idx] = set()
self.nodes.append(c)
self.nodes.append(new_node)
if _name:
if _name in self.named:
@ -120,7 +141,14 @@ class Graph:
return idx
def add_chain(self, *nodes, _input=BEGIN, _output=None, _name=None):
def get_or_add_node(self, new_node, *, _name=None):
if new_node in self.nodes:
if _name is not None:
raise RuntimeError("Cannot name a node that is already present in the graph.")
return self.index_of(new_node)
return self.add_node(new_node, _name=_name)
def add_chain(self, *nodes, _input=BEGIN, _output=None, _name=None, use_existing_nodes=False):
"""Add `nodes` as a chain in this graph.
**Input rules**
@ -153,6 +181,8 @@ class Graph:
_first = None
_last = None
get_node = self.get_or_add_node if use_existing_nodes else self.add_node
# Sanity checks.
if not len(nodes):
if _input is None or _output is None:
@ -164,7 +194,7 @@ class Graph:
raise RuntimeError("Using add_chain(...) without nodes does not allow to use the _name parameter.")
for i, node in enumerate(nodes):
_last = self.add_node(node, _name=_name if not i else None)
_last = get_node(node, _name=_name if not i else None)
if _first is None:
_first = _last

View File

@ -0,0 +1,129 @@
from operator import attrgetter
from unittest.mock import sentinel
import pytest
from bonobo.constants import BEGIN
from bonobo.structs.graphs import Graph, GraphCursor
from bonobo.util import tuplize
@tuplize
def get_pseudo_nodes(*names):
for name in names:
yield getattr(sentinel, name)
def test_get_cursor():
g = Graph()
cursor = g.get_cursor()
assert cursor.graph is g
assert cursor.first is BEGIN
assert cursor.last is BEGIN
def test_get_cursor_in_a_vacuum():
g = Graph()
cursor = g.get_cursor(None)
assert cursor.graph is g
assert cursor.first is None
assert cursor.last is None
def test_cursor_usage_to_add_a_chain():
a, b, c = get_pseudo_nodes(*"abc")
g = Graph()
g.get_cursor() >> a >> b >> c
assert len(g) == 3
assert g.outputs_of(BEGIN) == {g.index_of(a)}
assert g.outputs_of(a) == {g.index_of(b)}
assert g.outputs_of(b) == {g.index_of(c)}
assert g.outputs_of(c) == set()
def test_cursor_usage_to_add_a_chain_in_a_context_manager():
a, b, c = get_pseudo_nodes(*"abc")
g = Graph()
with g as cur:
cur >> a >> b >> c
assert len(g) == 3
assert g.outputs_of(BEGIN) == {g.index_of(a)}
assert g.outputs_of(a) == {g.index_of(b)}
assert g.outputs_of(b) == {g.index_of(c)}
assert g.outputs_of(c) == set()
def test_implicit_cursor_usage():
a, b, c = get_pseudo_nodes(*"abc")
g = Graph()
g >> a >> b >> c
assert len(g) == 3
assert g.outputs_of(BEGIN) == {g.index_of(a)}
assert g.outputs_of(a) == {g.index_of(b)}
assert g.outputs_of(b) == {g.index_of(c)}
assert g.outputs_of(c) == set()
def test_cursor_to_fork_a_graph():
a, b, c, d, e = get_pseudo_nodes(*"abcde")
g = Graph()
g >> a >> b >> c
g.get_cursor(b) >> d >> e
assert len(g) == 5
assert g.outputs_of(BEGIN) == {g.index_of(a)}
assert g.outputs_of(a) == {g.index_of(b)}
assert g.outputs_of(b) == {g.index_of(c), g.index_of(d)}
assert g.outputs_of(c) == set()
assert g.outputs_of(d) == {g.index_of(e)}
assert g.outputs_of(e) == set()
def test_cursor_to_fork_at_the_end():
a, b, c, d, e = get_pseudo_nodes(*"abcde")
g = Graph()
c0 = g >> a >> b
c1 = c0 >> c
c2 = c0 >> d >> e
assert len(g) == 5
assert g.outputs_of(BEGIN) == {g.index_of(a)}
assert g.outputs_of(a) == {g.index_of(b)}
assert g.outputs_of(b) == {g.index_of(c), g.index_of(d)}
assert g.outputs_of(c) == set()
assert g.outputs_of(d) == {g.index_of(e)}
assert g.outputs_of(e) == set()
assert c0.first == g.index_of(BEGIN)
assert c0.last == g.index_of(b)
assert c1.first == g.index_of(BEGIN)
assert c1.last == g.index_of(c)
assert c2.first == g.index_of(BEGIN)
assert c2.last == g.index_of(e)
def test_cursor_merge():
a, b, c = get_pseudo_nodes(*"abc")
g = Graph()
c1 = g >> a >> c
c2 = g >> b >> c
assert len(g) == 3
assert g.outputs_of(BEGIN) == g.indexes_of(a, b)
assert g.outputs_of(a) == g.indexes_of(c)
assert g.outputs_of(b) == g.indexes_of(c)
assert g.outputs_of(c) == set()
assert c1 == c2