diff --git a/bonobo/structs/graphs.py b/bonobo/structs/graphs.py index b74baae..ea1ccf9 100644 --- a/bonobo/structs/graphs.py +++ b/bonobo/structs/graphs.py @@ -12,6 +12,15 @@ from bonobo.util import get_name GraphRange = namedtuple("GraphRange", ["graph", "input", "output"]) +def coalesce(*values): + if not len(values): + raise ValueError("Cannot coalesce an empty list of arguments.") + for value in values: + if value is not None: + return value + return values[-1] + + class GraphCursor: @property def input(self): @@ -21,9 +30,13 @@ class GraphCursor: def output(self): return self.last + @property + def range(self): + return self.first, self.last + def __init__(self, graph, *, first=None, last=None): self.graph = graph - self.first = first or last + self.first = coalesce(first, last) self.last = last def __rshift__(self, other): @@ -46,7 +59,7 @@ class GraphCursor: # If there are nodes to add, create a new cursor after the chain is added to the graph. if len(nodes): chain = self.graph.add_chain(*nodes, _input=self.last, use_existing_nodes=True) - return GraphCursor(chain.graph, first=self.first, last=chain.output) + return GraphCursor(chain.graph, first=coalesce(self.first, chain.input), last=chain.output) # If we add nothing, then nothing changed. return self @@ -143,12 +156,12 @@ class Graph: raise ValueError("Cannot find node matching {!r}.".format(mixed)) - def indexes_of(self, *things): + def indexes_of(self, *things, _type=set): """ Returns the set of indexes of the things passed as arguments. """ - return set(map(self.index_of, things)) + return _type(map(self.index_of, things)) def outputs_of(self, idx_or_node, create=False): """ diff --git a/tests/structs/test_graphs_new_syntax.py b/tests/structs/test_graphs_new_syntax.py index 1740ab0..d29d78c 100644 --- a/tests/structs/test_graphs_new_syntax.py +++ b/tests/structs/test_graphs_new_syntax.py @@ -168,5 +168,53 @@ def test_concat_branches(): c1 = g >> c >> d c2 = c1 >> c0 + assert c0.first == g.index_of(a) assert c2.first == BEGIN assert c2.last == g.index_of(b) + + assert g.outputs_of(BEGIN) == g.indexes_of(c) + assert g.outputs_of(a) == g.indexes_of(b) + assert g.outputs_of(b) == set() + assert g.outputs_of(c) == g.indexes_of(d) + assert g.outputs_of(d) == g.indexes_of(a) + + +def test_add_branch_inbetween(): + a, b, c, d, e, f = get_pseudo_nodes(6) + g = Graph() + c0 = g.orphan() >> a >> b + c1 = g.orphan() >> c >> d + c2 = g.orphan() >> e >> f + c3 = c0 >> c1 >> c2 + + assert c0.range == g.indexes_of(a, b, _type=tuple) + assert c1.range == g.indexes_of(c, d, _type=tuple) + assert c2.range == g.indexes_of(e, f, _type=tuple) + assert c3.range == g.indexes_of(a, f, _type=tuple) + + assert g.outputs_of(b) == g.indexes_of(c) + assert g.outputs_of(d) == g.indexes_of(e) + assert g.outputs_of(f) == set() + + +def test_add_more_branches_inbetween(): + a, b, c, d, e, f, x, y = get_pseudo_nodes(8) + g = Graph() + c0 = g.orphan() >> a >> b + c1 = g.orphan() >> c >> d + c2 = g.orphan() >> e >> f + c3 = g.orphan() >> x >> y + c4 = c0 >> c1 >> c3 + c5 = c0 >> c2 >> c3 + + assert c0.range == g.indexes_of(a, b, _type=tuple) + assert c1.range == g.indexes_of(c, d, _type=tuple) + assert c2.range == g.indexes_of(e, f, _type=tuple) + assert c3.range == g.indexes_of(x, y, _type=tuple) + assert c4.range == g.indexes_of(a, y, _type=tuple) + assert c5.range == g.indexes_of(a, y, _type=tuple) + + assert g.outputs_of(b) == g.indexes_of(c, e) + assert g.outputs_of(d) == g.indexes_of(x) + assert g.outputs_of(f) == g.indexes_of(x) + assert g.outputs_of(y) == set()