diff --git a/bonobo/structs/graphs.py b/bonobo/structs/graphs.py index ce43ef6..b74baae 100644 --- a/bonobo/structs/graphs.py +++ b/bonobo/structs/graphs.py @@ -29,17 +29,26 @@ class GraphCursor: def __rshift__(self, other): """ Self >> Other """ + # Allow to concatenate cursors. + if isinstance(other, GraphCursor): + chain = self.graph.add_chain(_input=self.last, _output=other.first) + return GraphCursor(chain.graph, first=self.first, last=other.last) + + # If we get a partial graph, or anything with a node list, use that. nodes = other.nodes if hasattr(other, "nodes") else [other] + # Sometimes, we use ellipsis to show "pseudo-code". This is ok, but can't be executed. if ... in nodes: raise NotImplementedError( "Expected something looking like a node, but got an Ellipsis (...). Did you forget to complete the graph?" ) + # If there are nodes to add, create a new cursor after the chain is added to the graph. if len(nodes): chain = self.graph.add_chain(*nodes, _input=self.last, use_existing_nodes=True) return GraphCursor(chain.graph, first=self.first, last=chain.output) + # If we add nothing, then nothing changed. return self def __enter__(self): @@ -49,7 +58,10 @@ class GraphCursor: return None def __eq__(self, other): - return self.graph == other.graph and self.first == other.first and self.last == other.last + try: + return self.graph == other.graph and self.first == other.first and self.last == other.last + except AttributeError: + return False class PartialGraph: diff --git a/tests/structs/test_graphs_new_syntax.py b/tests/structs/test_graphs_new_syntax.py index dd8b355..1740ab0 100644 --- a/tests/structs/test_graphs_new_syntax.py +++ b/tests/structs/test_graphs_new_syntax.py @@ -158,3 +158,15 @@ def test_using_same_cursor_many_times_for_fork(): assert g.outputs_of(c) == set() assert g.outputs_of(d) == set() assert g.outputs_of(e) == set() + + +def test_concat_branches(): + a, b, c, d = get_pseudo_nodes(4) + g = Graph() + + c0 = g.orphan() >> a >> b + c1 = g >> c >> d + c2 = c1 >> c0 + + assert c2.first == BEGIN + assert c2.last == g.index_of(b)