refactoring for better testability

This commit is contained in:
Romain Dorgueil
2016-12-25 15:01:56 +01:00
parent deb7700353
commit 1fbd43a94d
4 changed files with 71 additions and 53 deletions

View File

@ -35,6 +35,12 @@ class ExecutionContext:
def __iter__(self):
yield from self.components
def impulse(self):
for i in self.graph.outputs_of(Begin):
self[i].recv(Begin)
self[i].recv(Bag())
self[i].recv(End)
@property
def running(self):
return any(component.running for component in self.components)
@ -77,6 +83,23 @@ def _iter(x):
return iter(x)
def _resolve(input_bag, output):
# NotModified means to send the input unmodified to output.
if output is NotModified:
return input_bag
# If it does not look like a bag, let's create one for easier manipulation
if hasattr(output, 'apply'):
# Already a bag? Check if we need to set parent.
if InheritInputFlag in output.flags:
output.set_parent(input_bag)
else:
# Not a bag? Let's encapsulate it.
output = Bag(output)
return output
class ComponentExecutionContext(WithStatistics):
"""
todo: make the counter dependant of parent context?
@ -149,34 +172,15 @@ class ComponentExecutionContext(WithStatistics):
output channel."""
input_bag = self.get()
def _resolve(output):
nonlocal input_bag
# NotModified means to send the input unmodified to output.
if output is NotModified:
return input_bag
# If it does not look like a bag, let's create one for easier manipulation
if hasattr(output, 'apply'):
# Already a bag? Check if we need to set parent.
if InheritInputFlag in output.flags:
output.set_parent(input_bag)
else:
# Not a bag? Let's encapsulate it.
output = Bag(result)
return output
results = self._call(input_bag)
outputs = self._call(input_bag)
# self._exec_time += timer.duration
# Put data onto output channels
try:
results = _iter(results)
outputs = _iter(outputs)
except TypeError:
if results:
self.send(_resolve(results))
if outputs:
self.send(_resolve(input_bag, outputs))
else:
# case with no result, an execution went through anyway, use for stats.
# self._exec_count += 1
@ -184,10 +188,10 @@ class ComponentExecutionContext(WithStatistics):
else:
while True:
try:
result = next(results)
output = next(outputs)
except StopIteration as e:
break
self.send(_resolve(result))
self.send(_resolve(input_bag, output))
def run(self):
assert self.state is New, ('A {} can only be run once, and thus is expected to be in {} state at the '

View File

@ -18,12 +18,9 @@ class ExecutorStrategy(Strategy):
def execute(self, graph, *args, plugins=None, **kwargs):
context = self.create_context(graph, plugins=plugins)
executor = self.executor_factory()
context.impulse()
for i in graph.outputs_of(Begin):
context[i].recv(Begin)
context[i].recv(Bag())
context[i].recv(End)
executor = self.executor_factory()
futures = []
@ -41,8 +38,7 @@ class ExecutorStrategy(Strategy):
executor.shutdown()
#for component_context in context.components:
# print(component_context)
return context
class ThreadPoolExecutorStrategy(ExecutorStrategy):

View File

@ -1,21 +1,14 @@
from queue import Queue, Empty
from bonobo.core.strategies.base import Strategy
from bonobo.util.iterators import force_iterator
class NaiveStrategy(Strategy):
def execute(self, graph, *args, **kwargs):
context = self.create_context(graph)
def execute(self, graph, *args, plugins=None, **kwargs):
context = self.create_context(graph, plugins=plugins)
context.impulse()
input_queues = {i: Queue() for i in range(len(context.graph.components))}
for i, component in enumerate(context.graph.components):
while True:
try:
args = (input_queues[i].get(block=False), ) if i else ()
for row in force_iterator(component(*args)):
input_queues[i + 1].put(row)
if not i:
raise Empty
except Empty:
break
# TODO: how to run plugins in "naive" mode ?
for component in context.components:
component.run()
return context

View File

@ -1,5 +1,6 @@
from bonobo import Graph
from bonobo import Graph, NaiveStrategy
from bonobo.core.contexts import ExecutionContext
from bonobo.util.lifecycle import with_context
def generate_integers():
@ -10,6 +11,16 @@ def square(i: int) -> int:
return i**2
@with_context
def push_result(ctx, i: int):
if not hasattr(ctx.parent, 'results'):
ctx.parent.results = []
ctx.parent.results.append(i)
chain = (generate_integers, square, push_result)
def test_empty_execution_context():
graph = Graph()
@ -20,15 +31,29 @@ def test_empty_execution_context():
assert not ctx.running
def test_execution():
graph = Graph()
graph.add_chain(*chain)
strategy = NaiveStrategy()
ctx = strategy.execute(graph)
assert ctx.results == [1, 4, 9, 16, 25, 36, 49, 64, 81]
def test_simple_execution_context():
graph = Graph()
graph.add_chain(generate_integers, square)
graph.add_chain(*chain)
ctx = ExecutionContext(graph)
assert len(ctx.components) == 2
assert len(ctx.components) == len(chain)
assert not len(ctx.plugins)
assert ctx[0].component is generate_integers
assert ctx[1].component is square
for i, component in enumerate(chain):
assert ctx[i].component is component
assert not ctx.running
ctx.impulse()
assert ctx.running