diff --git a/bonobo/structs/graphs.py b/bonobo/structs/graphs.py index 8d256de..aaf3fd6 100644 --- a/bonobo/structs/graphs.py +++ b/bonobo/structs/graphs.py @@ -37,11 +37,20 @@ class GraphCursor: ) if len(nodes): - chain = self.graph.add_chain(*nodes, _input=self.last) + chain = self.graph.add_chain(*nodes, _input=self.last, use_existing_nodes=True) return GraphCursor(chain.graph, first=self.first, last=chain.output) return self + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return None + + def __eq__(self, other): + return self.graph == other.graph and self.first == other.first and self.last == other.last + class PartialGraph: def __init__(self, *nodes): @@ -73,6 +82,15 @@ class Graph: def __getitem__(self, key): return self.nodes[key] + def __enter__(self): + return self.get_cursor().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + return None + + def __rshift__(self, other): + return self.get_cursor().__rshift__(other) + def get_cursor(self, ref=BEGIN): return GraphCursor(self, last=self.index_of(ref)) @@ -96,6 +114,9 @@ class Graph: raise ValueError("Cannot find node matching {!r}.".format(mixed)) + def indexes_of(self, *things): + return set(map(self.index_of, things)) + def outputs_of(self, idx_or_node, create=False): """Get a set of the outputs for a given node, node index or name. """ @@ -105,13 +126,13 @@ class Graph: self.edges[idx_or_node] = set() return self.edges[idx_or_node] - def add_node(self, c, *, _name=None): + def add_node(self, new_node, *, _name=None): """Add a node without connections in this graph and returns its index. If _name is specified, name this node (string reference for further usage). """ idx = len(self.nodes) self.edges[idx] = set() - self.nodes.append(c) + self.nodes.append(new_node) if _name: if _name in self.named: @@ -120,7 +141,14 @@ class Graph: return idx - def add_chain(self, *nodes, _input=BEGIN, _output=None, _name=None): + def get_or_add_node(self, new_node, *, _name=None): + if new_node in self.nodes: + if _name is not None: + raise RuntimeError("Cannot name a node that is already present in the graph.") + return self.index_of(new_node) + return self.add_node(new_node, _name=_name) + + def add_chain(self, *nodes, _input=BEGIN, _output=None, _name=None, use_existing_nodes=False): """Add `nodes` as a chain in this graph. **Input rules** @@ -153,6 +181,8 @@ class Graph: _first = None _last = None + get_node = self.get_or_add_node if use_existing_nodes else self.add_node + # Sanity checks. if not len(nodes): if _input is None or _output is None: @@ -164,7 +194,7 @@ class Graph: raise RuntimeError("Using add_chain(...) without nodes does not allow to use the _name parameter.") for i, node in enumerate(nodes): - _last = self.add_node(node, _name=_name if not i else None) + _last = get_node(node, _name=_name if not i else None) if _first is None: _first = _last diff --git a/tests/structs/test_graphs_new_syntax.py b/tests/structs/test_graphs_new_syntax.py new file mode 100644 index 0000000..570fa47 --- /dev/null +++ b/tests/structs/test_graphs_new_syntax.py @@ -0,0 +1,129 @@ +from operator import attrgetter +from unittest.mock import sentinel + +import pytest + +from bonobo.constants import BEGIN +from bonobo.structs.graphs import Graph, GraphCursor +from bonobo.util import tuplize + + +@tuplize +def get_pseudo_nodes(*names): + for name in names: + yield getattr(sentinel, name) + + +def test_get_cursor(): + g = Graph() + cursor = g.get_cursor() + + assert cursor.graph is g + assert cursor.first is BEGIN + assert cursor.last is BEGIN + + +def test_get_cursor_in_a_vacuum(): + g = Graph() + cursor = g.get_cursor(None) + + assert cursor.graph is g + assert cursor.first is None + assert cursor.last is None + + +def test_cursor_usage_to_add_a_chain(): + a, b, c = get_pseudo_nodes(*"abc") + + g = Graph() + + g.get_cursor() >> a >> b >> c + + assert len(g) == 3 + assert g.outputs_of(BEGIN) == {g.index_of(a)} + assert g.outputs_of(a) == {g.index_of(b)} + assert g.outputs_of(b) == {g.index_of(c)} + assert g.outputs_of(c) == set() + + +def test_cursor_usage_to_add_a_chain_in_a_context_manager(): + a, b, c = get_pseudo_nodes(*"abc") + + g = Graph() + with g as cur: + cur >> a >> b >> c + + assert len(g) == 3 + assert g.outputs_of(BEGIN) == {g.index_of(a)} + assert g.outputs_of(a) == {g.index_of(b)} + assert g.outputs_of(b) == {g.index_of(c)} + assert g.outputs_of(c) == set() + + +def test_implicit_cursor_usage(): + a, b, c = get_pseudo_nodes(*"abc") + + g = Graph() + g >> a >> b >> c + + assert len(g) == 3 + assert g.outputs_of(BEGIN) == {g.index_of(a)} + assert g.outputs_of(a) == {g.index_of(b)} + assert g.outputs_of(b) == {g.index_of(c)} + assert g.outputs_of(c) == set() + + +def test_cursor_to_fork_a_graph(): + a, b, c, d, e = get_pseudo_nodes(*"abcde") + + g = Graph() + g >> a >> b >> c + g.get_cursor(b) >> d >> e + + assert len(g) == 5 + assert g.outputs_of(BEGIN) == {g.index_of(a)} + assert g.outputs_of(a) == {g.index_of(b)} + assert g.outputs_of(b) == {g.index_of(c), g.index_of(d)} + assert g.outputs_of(c) == set() + assert g.outputs_of(d) == {g.index_of(e)} + assert g.outputs_of(e) == set() + + +def test_cursor_to_fork_at_the_end(): + a, b, c, d, e = get_pseudo_nodes(*"abcde") + + g = Graph() + c0 = g >> a >> b + c1 = c0 >> c + c2 = c0 >> d >> e + + assert len(g) == 5 + assert g.outputs_of(BEGIN) == {g.index_of(a)} + assert g.outputs_of(a) == {g.index_of(b)} + assert g.outputs_of(b) == {g.index_of(c), g.index_of(d)} + assert g.outputs_of(c) == set() + assert g.outputs_of(d) == {g.index_of(e)} + assert g.outputs_of(e) == set() + + assert c0.first == g.index_of(BEGIN) + assert c0.last == g.index_of(b) + assert c1.first == g.index_of(BEGIN) + assert c1.last == g.index_of(c) + assert c2.first == g.index_of(BEGIN) + assert c2.last == g.index_of(e) + + +def test_cursor_merge(): + a, b, c = get_pseudo_nodes(*"abc") + g = Graph() + + c1 = g >> a >> c + c2 = g >> b >> c + + assert len(g) == 3 + assert g.outputs_of(BEGIN) == g.indexes_of(a, b) + assert g.outputs_of(a) == g.indexes_of(c) + assert g.outputs_of(b) == g.indexes_of(c) + assert g.outputs_of(c) == set() + + assert c1 == c2