refactoring for better testability

This commit is contained in:
Romain Dorgueil
2016-12-25 15:01:56 +01:00
parent deb7700353
commit 1fbd43a94d
4 changed files with 71 additions and 53 deletions

View File

@ -35,6 +35,12 @@ class ExecutionContext:
def __iter__(self): def __iter__(self):
yield from self.components yield from self.components
def impulse(self):
for i in self.graph.outputs_of(Begin):
self[i].recv(Begin)
self[i].recv(Bag())
self[i].recv(End)
@property @property
def running(self): def running(self):
return any(component.running for component in self.components) return any(component.running for component in self.components)
@ -77,6 +83,23 @@ def _iter(x):
return iter(x) return iter(x)
def _resolve(input_bag, output):
# NotModified means to send the input unmodified to output.
if output is NotModified:
return input_bag
# If it does not look like a bag, let's create one for easier manipulation
if hasattr(output, 'apply'):
# Already a bag? Check if we need to set parent.
if InheritInputFlag in output.flags:
output.set_parent(input_bag)
else:
# Not a bag? Let's encapsulate it.
output = Bag(output)
return output
class ComponentExecutionContext(WithStatistics): class ComponentExecutionContext(WithStatistics):
""" """
todo: make the counter dependant of parent context? todo: make the counter dependant of parent context?
@ -149,34 +172,15 @@ class ComponentExecutionContext(WithStatistics):
output channel.""" output channel."""
input_bag = self.get() input_bag = self.get()
outputs = self._call(input_bag)
def _resolve(output):
nonlocal input_bag
# NotModified means to send the input unmodified to output.
if output is NotModified:
return input_bag
# If it does not look like a bag, let's create one for easier manipulation
if hasattr(output, 'apply'):
# Already a bag? Check if we need to set parent.
if InheritInputFlag in output.flags:
output.set_parent(input_bag)
else:
# Not a bag? Let's encapsulate it.
output = Bag(result)
return output
results = self._call(input_bag)
# self._exec_time += timer.duration # self._exec_time += timer.duration
# Put data onto output channels # Put data onto output channels
try: try:
results = _iter(results) outputs = _iter(outputs)
except TypeError: except TypeError:
if results: if outputs:
self.send(_resolve(results)) self.send(_resolve(input_bag, outputs))
else: else:
# case with no result, an execution went through anyway, use for stats. # case with no result, an execution went through anyway, use for stats.
# self._exec_count += 1 # self._exec_count += 1
@ -184,10 +188,10 @@ class ComponentExecutionContext(WithStatistics):
else: else:
while True: while True:
try: try:
result = next(results) output = next(outputs)
except StopIteration as e: except StopIteration as e:
break break
self.send(_resolve(result)) self.send(_resolve(input_bag, output))
def run(self): def run(self):
assert self.state is New, ('A {} can only be run once, and thus is expected to be in {} state at the ' assert self.state is New, ('A {} can only be run once, and thus is expected to be in {} state at the '

View File

@ -18,12 +18,9 @@ class ExecutorStrategy(Strategy):
def execute(self, graph, *args, plugins=None, **kwargs): def execute(self, graph, *args, plugins=None, **kwargs):
context = self.create_context(graph, plugins=plugins) context = self.create_context(graph, plugins=plugins)
executor = self.executor_factory() context.impulse()
for i in graph.outputs_of(Begin): executor = self.executor_factory()
context[i].recv(Begin)
context[i].recv(Bag())
context[i].recv(End)
futures = [] futures = []
@ -41,8 +38,7 @@ class ExecutorStrategy(Strategy):
executor.shutdown() executor.shutdown()
#for component_context in context.components: return context
# print(component_context)
class ThreadPoolExecutorStrategy(ExecutorStrategy): class ThreadPoolExecutorStrategy(ExecutorStrategy):

View File

@ -1,21 +1,14 @@
from queue import Queue, Empty
from bonobo.core.strategies.base import Strategy from bonobo.core.strategies.base import Strategy
from bonobo.util.iterators import force_iterator
class NaiveStrategy(Strategy): class NaiveStrategy(Strategy):
def execute(self, graph, *args, **kwargs): def execute(self, graph, *args, plugins=None, **kwargs):
context = self.create_context(graph) context = self.create_context(graph, plugins=plugins)
context.impulse()
input_queues = {i: Queue() for i in range(len(context.graph.components))} # TODO: how to run plugins in "naive" mode ?
for i, component in enumerate(context.graph.components):
while True: for component in context.components:
try: component.run()
args = (input_queues[i].get(block=False), ) if i else ()
for row in force_iterator(component(*args)): return context
input_queues[i + 1].put(row)
if not i:
raise Empty
except Empty:
break

View File

@ -1,5 +1,6 @@
from bonobo import Graph from bonobo import Graph, NaiveStrategy
from bonobo.core.contexts import ExecutionContext from bonobo.core.contexts import ExecutionContext
from bonobo.util.lifecycle import with_context
def generate_integers(): def generate_integers():
@ -10,6 +11,16 @@ def square(i: int) -> int:
return i**2 return i**2
@with_context
def push_result(ctx, i: int):
if not hasattr(ctx.parent, 'results'):
ctx.parent.results = []
ctx.parent.results.append(i)
chain = (generate_integers, square, push_result)
def test_empty_execution_context(): def test_empty_execution_context():
graph = Graph() graph = Graph()
@ -20,15 +31,29 @@ def test_empty_execution_context():
assert not ctx.running assert not ctx.running
def test_execution():
graph = Graph()
graph.add_chain(*chain)
strategy = NaiveStrategy()
ctx = strategy.execute(graph)
assert ctx.results == [1, 4, 9, 16, 25, 36, 49, 64, 81]
def test_simple_execution_context(): def test_simple_execution_context():
graph = Graph() graph = Graph()
graph.add_chain(generate_integers, square) graph.add_chain(*chain)
ctx = ExecutionContext(graph) ctx = ExecutionContext(graph)
assert len(ctx.components) == 2 assert len(ctx.components) == len(chain)
assert not len(ctx.plugins) assert not len(ctx.plugins)
assert ctx[0].component is generate_integers for i, component in enumerate(chain):
assert ctx[1].component is square assert ctx[i].component is component
assert not ctx.running assert not ctx.running
ctx.impulse()
assert ctx.running