Big refactoring, way simpler management of transformations. Early branch for upcoming version 0.2.

This commit is contained in:
Romain Dorgueil
2017-01-17 22:45:10 +01:00
parent b6e84c66e3
commit 8532520aae
41 changed files with 627 additions and 433 deletions

1
.gitignore vendored
View File

@ -23,6 +23,7 @@
.tox/ .tox/
.webassets-cache .webassets-cache
/.idea /.idea
/.release
/bonobo.iml /bonobo.iml
/bonobo/ext/jupyter/js/node_modules/ /bonobo/ext/jupyter/js/node_modules/
/build/ /build/

70
ROADMAP.rst Normal file
View File

@ -0,0 +1,70 @@
Roadmap
=======
Milestones
::::::::::
* Class-tree for Graph and Nodes
* Class-tree for execution contexts:
* GraphExecutionContext
* NodeExecutionContext
* PluginExecutionContext
* Class-tree for ExecutionStrategies
* NaiveStrategy
* PoolExecutionStrategy
* ThreadPoolExecutionStrategy
* ProcessPoolExecutionStrategy
* ThreadExecutionStrategy
* ProcessExecutionStrategy
* Class-tree for bags
* Bag
* ErrorBag
* InheritingBag
*
* Co-routines: for unordered, or even ordered but long io.
* "context processors": replace initialize/finalize by a generator that yields only once
* "execute" function:
.. code-block:: python
def execute(graph: Graph, *, strategy: ExecutionStrategy, plugins: List[Plugin]) -> Execution:
pass
Version 0.2
:::::::::::
* Changelog
* Migration guide
* Update documentation
* Threaded does not terminate anymore
* More tests
Configuration
:::::::::::::
* Support for positional arguments (options), required options are good candidates.
Context processors
::::::::::::::::::
* Be careful with order, especially with python 3.5.
* @contextual decorator is not clean enough. Once the behavior is right, find a way to use regular inheritance, without meta.
* ValueHolder API not clean. Find a better way.

View File

@ -20,19 +20,55 @@
limitations under the License. limitations under the License.
""" """
import sys import sys
import warnings
assert (sys.version_info >= (3, 5)), 'Python 3.5+ is required to use Bonobo.' assert (sys.version_info >= (3, 5)), 'Python 3.5+ is required to use Bonobo.'
from ._version import __version__ from ._version import __version__
from .config import *
from .context import *
from .core import * from .core import *
from .io import CsvReader, CsvWriter, FileReader, FileWriter, JsonReader, JsonWriter from .io import *
from .util import * from .util import *
DEFAULT_STRATEGY = 'threadpool'
STRATEGIES = {
'naive': NaiveStrategy,
'processpool': ProcessPoolExecutorStrategy,
'threadpool': ThreadPoolExecutorStrategy,
}
def run(graph, *chain, strategy=None, plugins=None):
    """Execute a graph using the requested execution strategy.

    :param graph: the graph to execute. For backward compatibility, this can
        also be the first node of a chain, with the remaining nodes passed
        positionally in ``chain`` (deprecated).
    :param chain: deprecated; additional nodes used to build a ``Graph``.
    :param strategy: a ``Strategy`` instance, a ``Strategy`` subclass, a key of
        ``STRATEGIES``, or ``None`` to use ``DEFAULT_STRATEGY``.
    :param plugins: forwarded as-is to the strategy's ``execute()``.
    :return: whatever the strategy's ``execute()`` returns.
    :raises RuntimeError: if ``strategy`` is not a known strategy name.
    """
    # The module-level ``warnings`` name is deleted at the end of this module,
    # so it must be imported locally to still be available at call time.
    import warnings

    from bonobo.core.strategies.base import Strategy

    if len(chain):
        warnings.warn('DEPRECATED. You should pass a Graph instance instead of a chain.', stacklevel=2)
        from bonobo import Graph
        graph = Graph(graph, *chain)

    if isinstance(strategy, type) and issubclass(strategy, Strategy):
        # A strategy class was given (the former helper accepted this form).
        strategy = strategy()
    elif not isinstance(strategy, Strategy):
        if strategy is None:
            strategy = DEFAULT_STRATEGY

        try:
            strategy = STRATEGIES[strategy]
        except KeyError as exc:
            raise RuntimeError('Invalid strategy {}.'.format(repr(strategy))) from exc

        strategy = strategy()

    return strategy.execute(graph, plugins=plugins)
__all__ = [ __all__ = [
'Bag', 'Bag',
'Configurable',
'ContextProcessor',
'contextual',
'CsvReader', 'CsvReader',
'CsvWriter', 'CsvWriter',
'Configurable',
'FileReader', 'FileReader',
'FileWriter', 'FileWriter',
'Graph', 'Graph',
@ -51,9 +87,9 @@ __all__ = [
'log', 'log',
'noop', 'noop',
'pprint', 'pprint',
'run',
'service', 'service',
'tee', 'tee',
] ]
del warnings
del sys del sys

View File

@ -1,9 +0,0 @@
from bonobo import FileWriter, JsonWriter
to_file = FileWriter
to_json = JsonWriter
__all__ = [
'to_json',
'to_file',
]

View File

@ -33,6 +33,8 @@ class Configurable(metaclass=ConfigurableMeta):
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__()
self.__options_values__ = {} self.__options_values__ = {}
missing = set() missing = set()

View File

@ -0,0 +1,6 @@
from bonobo.context.processors import contextual, ContextProcessor
__all__ = [
'ContextProcessor',
'contextual',
]

View File

@ -3,39 +3,31 @@ from functools import partial
from queue import Empty from queue import Empty
from time import sleep from time import sleep
from bonobo.context.processors import get_context_processors
from bonobo.core.bags import Bag, INHERIT_INPUT, ErrorBag from bonobo.core.bags import Bag, INHERIT_INPUT, ErrorBag
from bonobo.core.errors import InactiveReadableError from bonobo.core.errors import InactiveReadableError
from bonobo.core.inputs import Input from bonobo.core.inputs import Input
from bonobo.core.stats import WithStatistics from bonobo.core.statistics import WithStatistics
from bonobo.util.lifecycle import get_initializer, get_finalizer from bonobo.util.objects import Wrapper
from bonobo.util.tokens import BEGIN, END, NEW, RUNNING, TERMINATED, NOT_MODIFIED from bonobo.util.tokens import BEGIN, END, NOT_MODIFIED
def get_name(mixed): class GraphExecutionContext:
try: @property
return mixed.__name__ def started(self):
except AttributeError: return any(node.started for node in self.nodes)
return type(mixed).__name__
@property
def stopped(self):
return all(node.started and node.stopped for node in self.nodes)
def create_component_context(component, parent): @property
try: def alive(self):
CustomComponentContext = component.Context return self.started and not self.stopped
except AttributeError:
return ComponentExecutionContext(component, parent=parent)
if ComponentExecutionContext in CustomComponentContext.__mro__:
bases = (CustomComponentContext, )
else:
bases = (CustomComponentContext, ComponentExecutionContext)
return type(get_name(component).title() + 'ExecutionContext', bases, {})(component, parent=parent)
class ExecutionContext:
def __init__(self, graph, plugins=None): def __init__(self, graph, plugins=None):
self.graph = graph self.graph = graph
self.components = [create_component_context(component, parent=self) for component in self.graph.components] self.nodes = [NodeExecutionContext(node, parent=self) for node in self.graph.nodes]
self.plugins = [PluginExecutionContext(plugin, parent=self) for plugin in plugins or ()] self.plugins = [PluginExecutionContext(plugin, parent=self) for plugin in plugins or ()]
for i, component_context in enumerate(self): for i, component_context in enumerate(self):
@ -46,16 +38,16 @@ class ExecutionContext:
component_context.input.on_begin = partial(component_context.send, BEGIN, _control=True) component_context.input.on_begin = partial(component_context.send, BEGIN, _control=True)
component_context.input.on_end = partial(component_context.send, END, _control=True) component_context.input.on_end = partial(component_context.send, END, _control=True)
component_context.input.on_finalize = partial(component_context.finalize) component_context.input.on_finalize = partial(component_context.stop)
def __getitem__(self, item): def __getitem__(self, item):
return self.components[item] return self.nodes[item]
def __len__(self): def __len__(self):
return len(self.components) return len(self.nodes)
def __iter__(self): def __iter__(self):
yield from self.components yield from self.nodes
def recv(self, *messages): def recv(self, *messages):
"""Push a list of messages in the inputs of this graph's inputs, matching the output of special node "BEGIN" in """Push a list of messages in the inputs of this graph's inputs, matching the output of special node "BEGIN" in
@ -65,31 +57,67 @@ class ExecutionContext:
for message in messages: for message in messages:
self[i].recv(message) self[i].recv(message)
@property def start(self):
def alive(self): # todo use strategy
return any(component.alive for component in self.components) for node in self.nodes:
node.start()
def loop(self):
# todo use strategy
for node in self.nodes:
node.loop()
def stop(self):
# todo use strategy
for node in self.nodes:
node.stop()
class AbstractLoopContext: def ensure_tuple(tuple_or_mixed):
if isinstance(tuple_or_mixed, tuple):
return tuple_or_mixed
return (tuple_or_mixed, )
class LoopingExecutionContext(Wrapper):
alive = True alive = True
PERIOD = 0.25 PERIOD = 0.25
def __init__(self, wrapped): @property
self.wrapped = wrapped def state(self):
return self._started, self._stopped
def run(self): @property
self.initialize() def started(self):
self.loop() return self._started
self.finalize()
def initialize(self): @property
# pylint: disable=broad-except def stopped(self):
try: return self._stopped
initializer = get_initializer(self.wrapped)
except Exception as exc: def __init__(self, wrapped, parent):
self.handle_error(exc, traceback.format_exc()) super().__init__(wrapped)
else: self.parent = parent
return initializer(self) self._started, self._stopped, self._context, self._stack = False, False, None, []
def start(self):
assert self.state == (False, False), ('{}.start() can only be called on a new node.'
).format(type(self).__name__)
assert self._context is None
self._started = True
self._context = ()
for processor in get_context_processors(self.wrapped):
_processed = processor(self.wrapped, self, *self._context)
try:
# todo yield from ?
_append_to_context = next(_processed)
if _append_to_context is not None:
self._context += ensure_tuple(_append_to_context)
except Exception as exc: # pylint: disable=broad-except
self.handle_error(exc, traceback.format_exc())
raise
self._stack.append(_processed)
def loop(self): def loop(self):
"""Generic loop. A bit boring. """ """Generic loop. A bit boring. """
@ -103,15 +131,28 @@ class AbstractLoopContext:
""" """
raise NotImplementedError('Abstract.') raise NotImplementedError('Abstract.')
def finalize(self): def stop(self):
"""Generic finalizer. """ assert self._started, ('{}.stop() can only be called on a previously started node.').format(type(self).__name__)
# pylint: disable=broad-except if self._stopped:
try: return
finalizer = get_finalizer(self.wrapped)
except Exception as exc: assert self._context is not None
return self.handle_error(exc, traceback.format_exc())
else: self._stopped = True
return finalizer(self) while len(self._stack):
processor = self._stack.pop()
try:
# todo yield from ? how to ?
next(processor)
except StopIteration as exc:
# This is normal, and wanted.
pass
except Exception as exc: # pylint: disable=broad-except
self.handle_error(exc, traceback.format_exc())
raise
else:
# No error ? We should have had StopIteration ...
raise RuntimeError('Context processors should not yield more than once.')
def handle_error(self, exc, trace): def handle_error(self, exc, trace):
""" """
@ -129,11 +170,9 @@ class AbstractLoopContext:
print(trace) print(trace)
class PluginExecutionContext(AbstractLoopContext): class PluginExecutionContext(LoopingExecutionContext):
def __init__(self, plugin, parent): def __init__(self, wrapped, parent):
self.plugin = plugin LoopingExecutionContext.__init__(self, wrapped, parent)
self.parent = parent
super().__init__(self.plugin)
def shutdown(self): def shutdown(self):
self.alive = False self.alive = False
@ -145,7 +184,7 @@ class PluginExecutionContext(AbstractLoopContext):
self.handle_error(exc, traceback.format_exc()) self.handle_error(exc, traceback.format_exc())
class ComponentExecutionContext(WithStatistics, AbstractLoopContext): class NodeExecutionContext(WithStatistics, LoopingExecutionContext):
""" """
todo: make the counter dependant of parent context? todo: make the counter dependant of parent context?
""" """
@ -153,47 +192,20 @@ class ComponentExecutionContext(WithStatistics, AbstractLoopContext):
@property @property
def alive(self): def alive(self):
"""todo check if this is right, and where it is used""" """todo check if this is right, and where it is used"""
return self.input.alive return self.input.alive and self._started and not self._stopped
@property def __init__(self, wrapped, parent):
def name(self): LoopingExecutionContext.__init__(self, wrapped, parent)
return getattr(self.component, '__name__', getattr(type(self.component), '__name__', repr(self.component))) WithStatistics.__init__(self, 'in', 'out', 'err')
def __init__(self, component, parent):
self.parent = parent
self.component = component
self.input = Input() self.input = Input()
self.outputs = [] self.outputs = []
self.state = NEW
self.stats = {
'in': 0,
'out': 0,
'err': 0,
'read': 0,
'write': 0,
}
super().__init__(self.component) def __str__(self):
return (('+' if self.alive else '-') + ' ' + self.__name__ + ' ' + self.get_statistics_as_string()).strip()
def __repr__(self): def __repr__(self):
"""Adds "alive" information to the transform representation.""" return '<' + self.__str__() + '>'
return ('+' if self.alive else '-') + ' ' + self.name + ' ' + self.get_stats_as_string()
def get_stats(self, *args, **kwargs):
return (
(
'in',
self.stats['in'],
),
(
'out',
self.stats['out'],
),
(
'err',
self.stats['err'],
),
)
def recv(self, *messages): def recv(self, *messages):
""" """
@ -212,7 +224,7 @@ class ComponentExecutionContext(WithStatistics, AbstractLoopContext):
:param _control: if true, won't count in statistics. :param _control: if true, won't count in statistics.
""" """
if not _control: if not _control:
self.stats['out'] += 1 self.increment('out')
for output in self.outputs: for output in self.outputs:
output.put(value) output.put(value)
@ -222,41 +234,9 @@ class ComponentExecutionContext(WithStatistics, AbstractLoopContext):
""" """
row = self.input.get(timeout=self.PERIOD) row = self.input.get(timeout=self.PERIOD)
self.stats['in'] += 1 self.increment('in')
return row return row
def apply_on(self, bag):
# todo add timer
if getattr(self.component, '_with_context', False):
return bag.apply(self.component, self)
return bag.apply(self.component)
def initialize(self):
# pylint: disable=broad-except
assert self.state is NEW, (
'A {} can only be run once, and thus is expected to be in {} state at '
'initialization time.'
).format(type(self).__name__, NEW)
self.state = RUNNING
try:
initializer_outputs = super().initialize()
self.handle(None, initializer_outputs)
except Exception as exc:
self.handle_error(exc, traceback.format_exc())
def finalize(self):
# pylint: disable=broad-except
assert self.state is RUNNING, ('A {} must be in {} state at finalization time.'
).format(type(self).__name__, RUNNING)
self.state = TERMINATED
try:
finalizer_outputs = super().finalize()
self.handle(None, finalizer_outputs)
except Exception as exc:
self.handle_error(exc, traceback.format_exc())
def loop(self): def loop(self):
while True: while True:
try: try:
@ -277,24 +257,21 @@ class ComponentExecutionContext(WithStatistics, AbstractLoopContext):
output channel.""" output channel."""
input_bag = self.get() input_bag = self.get()
outputs = self.apply_on(input_bag)
self.handle(input_bag, outputs)
def run(self): # todo add timer
self.initialize() self.handle_results(input_bag, input_bag.apply(self.wrapped, *self._context))
self.loop()
def handle(self, input_bag, outputs): def handle_results(self, input_bag, results):
# self._exec_time += timer.duration # self._exec_time += timer.duration
# Put data onto output channels # Put data onto output channels
try: try:
outputs = _iter(outputs) results = _iter(results)
except TypeError: # not an iterator except TypeError: # not an iterator
if outputs: if results:
if isinstance(outputs, ErrorBag): if isinstance(results, ErrorBag):
outputs.apply(self.handle_error) results.apply(self.handle_error)
else: else:
self.send(_resolve(input_bag, outputs)) self.send(_resolve(input_bag, results))
else: else:
# case with no result, an execution went through anyway, use for stats. # case with no result, an execution went through anyway, use for stats.
# self._exec_count += 1 # self._exec_count += 1
@ -302,7 +279,7 @@ class ComponentExecutionContext(WithStatistics, AbstractLoopContext):
else: else:
while True: # iterator while True: # iterator
try: try:
output = next(outputs) output = next(results)
except StopIteration: except StopIteration:
break break
else: else:

View File

@ -0,0 +1,45 @@
import types
# Attribute name under which context processors are stored on functions and classes.
_CONTEXT_PROCESSORS_ATTR = '__processors__'


def get_context_processors(mixed):
    """Yield the context processors attached to a function, class or instance.

    For a plain function, processors stored directly on the function object
    are yielded first. Then the class hierarchy (of ``mixed`` itself if it is
    a class, of its type otherwise) is walked from the most generic base down
    to the most specific class, yielding the processors each class declares in
    its own ``__dict__``.
    """
    if isinstance(mixed, types.FunctionType):
        yield from getattr(mixed, _CONTEXT_PROCESSORS_ATTR, ())

    klass = mixed if isinstance(mixed, type) else type(mixed)
    for ancestor in klass.__mro__[::-1]:
        # Only look at what each class declares itself, not what it inherits.
        yield from ancestor.__dict__.get(_CONTEXT_PROCESSORS_ATTR, ())

    return ()
class ContextProcessor:
    """Marker wrapper around a callable, identifying it as a context processor.

    Calling the wrapper simply calls the wrapped callable with the same
    arguments; ``__name__`` and ``repr()`` are derived from the wrapped callable.
    """

    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        # Plain delegation to the wrapped callable.
        return self.func(*args, **kwargs)

    @property
    def __name__(self):
        return self.func.__name__

    def __repr__(self):
        # Reuse the wrapped function's repr, swapping in this class' name.
        own_prefix = '<{}'.format(type(self).__name__)
        wrapped_repr = repr(self.func)
        return wrapped_repr.replace('<function', own_prefix)
def contextual(cls_or_func):
    """Decorator enabling context processors on a function or a class.

    For a function, an empty processors list is attached if none exists yet.
    For a class, a processors list owned by this class (not inherited from a
    base) is created if needed, then every ``ContextProcessor`` found in the
    class' own ``__dict__`` is registered in it.

    Registration is idempotent: applying the decorator more than once to the
    same class does not register the same processor object twice.

    :param cls_or_func: the function or class to decorate.
    :return: the very same object, for use as a decorator.
    """
    if isinstance(cls_or_func, types.FunctionType):
        if not hasattr(cls_or_func, _CONTEXT_PROCESSORS_ATTR):
            setattr(cls_or_func, _CONTEXT_PROCESSORS_ATTR, [])
        return cls_or_func

    # Check the class' own namespace so each class gets its own list instead of
    # sharing (and mutating) one inherited from a base class.
    if _CONTEXT_PROCESSORS_ATTR not in cls_or_func.__dict__:
        setattr(cls_or_func, _CONTEXT_PROCESSORS_ATTR, [])

    _processors = getattr(cls_or_func, _CONTEXT_PROCESSORS_ATTR)
    for value in cls_or_func.__dict__.values():
        if not isinstance(value, ContextProcessor):
            continue
        # Identity check: skip processors that are already registered.
        if any(value is known for known in _processors):
            continue
        _processors.append(value)

    return cls_or_func

View File

@ -54,7 +54,6 @@ class Bag:
return generator() return generator()
except TypeError as exc: except TypeError as exc:
print('nop')
raise TypeError('Could not apply bag to {}.'.format(func_or_iter)) from exc raise TypeError('Could not apply bag to {}.'.format(func_or_iter)) from exc
raise TypeError('Could not apply bag to {}.'.format(func_or_iter)) raise TypeError('Could not apply bag to {}.'.format(func_or_iter))

View File

@ -3,11 +3,11 @@ from bonobo.util.tokens import BEGIN
class Graph: class Graph:
""" """
Represents a coherent directed acyclic graph (DAG) of components. Represents a coherent directed acyclic graph of components.
""" """
def __init__(self, *chain): def __init__(self, *chain):
self.components = [] self.nodes = []
self.graph = {BEGIN: set()} self.graph = {BEGIN: set()}
self.add_chain(*chain) self.add_chain(*chain)
@ -16,13 +16,13 @@ class Graph:
self.graph[idx] = set() self.graph[idx] = set()
return self.graph[idx] return self.graph[idx]
def add_component(self, c): def add_node(self, c):
i = len(self.components) i = len(self.nodes)
self.components.append(c) self.nodes.append(c)
return i return i
def add_chain(self, *components, _input=BEGIN): def add_chain(self, *nodes, _input=BEGIN):
for component in components: for node in nodes:
_next = self.add_component(component) _next = self.add_node(node)
self.outputs_of(_input, create=True).add(_next) self.outputs_of(_input, create=True).add(_next)
_input = _next _input = _next

View File

@ -1,7 +1,5 @@
import functools import functools
import itertools import itertools
from functools import partial
class service: class service:
@ -23,7 +21,7 @@ class service:
return item return item
def define(self, *args, **kwargs): def define(self, *args, **kwargs):
new_service = type(self)(partial(self.factory, *args, **kwargs)) new_service = type(self)(functools.partial(self.factory, *args, **kwargs))
self.children.add(new_service) self.children.add(new_service)
return new_service return new_service

View File

@ -14,15 +14,17 @@
# see the license for the specific language governing permissions and # see the license for the specific language governing permissions and
# limitations under the license. # limitations under the license.
from abc import ABCMeta, abstractmethod
from bonobo.core.errors import AbstractError class WithStatistics:
def __init__(self, *names):
self.statistics_names = names
self.statistics = {name: 0 for name in names}
def get_statistics(self, *args, **kwargs):
return ((name, self.statistics[name]) for name in self.statistics_names)
class WithStatistics(metaclass=ABCMeta): def get_statistics_as_string(self, *args, **kwargs):
@abstractmethod return ' '.join(('{0}={1}'.format(name, cnt) for name, cnt in self.get_statistics(*args, **kwargs) if cnt > 0))
def get_stats(self, *args, **kwargs):
raise AbstractError(self.get_stats)
def get_stats_as_string(self, *args, **kwargs): def increment(self, name):
return ' '.join(('{0}={1}'.format(name, cnt) for name, cnt in self.get_stats(*args, **kwargs) if cnt > 0)) self.statistics[name] += 1

View File

@ -1,4 +1,4 @@
from bonobo.core.contexts import ExecutionContext from bonobo.context.execution import GraphExecutionContext
class Strategy: class Strategy:
@ -6,10 +6,10 @@ class Strategy:
Base class for execution strategies. Base class for execution strategies.
""" """
context_type = ExecutionContext graph_execution_context_factory = GraphExecutionContext
def create_context(self, graph, *args, **kwargs): def create_graph_execution_context(self, graph, *args, **kwargs):
return self.context_type(graph, *args, **kwargs) return self.graph_execution_context_factory(graph, *args, **kwargs)
def execute(self, graph, *args, **kwargs): def execute(self, graph, *args, **kwargs):
raise NotImplementedError raise NotImplementedError

View File

@ -17,19 +17,34 @@ class ExecutorStrategy(Strategy):
executor_factory = Executor executor_factory = Executor
def create_executor(self):
return self.executor_factory()
def execute(self, graph, *args, plugins=None, **kwargs): def execute(self, graph, *args, plugins=None, **kwargs):
context = self.create_context(graph, plugins=plugins) context = self.create_graph_execution_context(graph, plugins=plugins)
context.recv(BEGIN, Bag(), END) context.recv(BEGIN, Bag(), END)
executor = self.executor_factory() executor = self.create_executor()
futures = [] futures = []
for plugin_context in context.plugins: for plugin_context in context.plugins:
futures.append(executor.submit(plugin_context.run))
for component_context in context.components: def _runner(plugin_context=plugin_context):
futures.append(executor.submit(component_context.run)) plugin_context.start()
plugin_context.loop()
plugin_context.stop()
futures.append(executor.submit(_runner))
for node_context in context.nodes:
def _runner(node_context=node_context):
node_context.start()
node_context.loop()
futures.append(executor.submit(_runner))
while context.alive: while context.alive:
time.sleep(0.2) time.sleep(0.2)
@ -52,7 +67,7 @@ class ProcessPoolExecutorStrategy(ExecutorStrategy):
class ThreadCollectionStrategy(Strategy): class ThreadCollectionStrategy(Strategy):
def execute(self, graph, *args, plugins=None, **kwargs): def execute(self, graph, *args, plugins=None, **kwargs):
context = self.create_context(graph, plugins=plugins) context = self.create_graph_execution_context(graph, plugins=plugins)
context.recv(BEGIN, Bag(), END) context.recv(BEGIN, Bag(), END)
threads = [] threads = []

View File

@ -6,12 +6,12 @@ from ..bags import Bag
class NaiveStrategy(Strategy): class NaiveStrategy(Strategy):
def execute(self, graph, *args, plugins=None, **kwargs): def execute(self, graph, *args, plugins=None, **kwargs):
context = self.create_context(graph, plugins=plugins) context = self.create_graph_execution_context(graph, plugins=plugins)
context.recv(BEGIN, Bag(), END) context.recv(BEGIN, Bag(), END)
# TODO: how to run plugins in "naive" mode ? # TODO: how to run plugins in "naive" mode ?
context.start()
for component in context.components: context.loop()
component.run() context.stop()
return context return context

View File

@ -86,7 +86,7 @@ class ConsoleOutputPlugin(Plugin):
' ', ' ',
component.name, component.name,
' ', ' ',
component.get_stats_as_string( component.get_statistics_as_string(
debug=debug, profile=profile debug=debug, profile=profile
), ),
' ', ' ',
@ -100,7 +100,7 @@ class ConsoleOutputPlugin(Plugin):
' - ', ' - ',
component.name, component.name,
' ', ' ',
component.get_stats_as_string( component.get_statistics_as_string(
debug=debug, profile=profile debug=debug, profile=profile
), ),
' ', ' ',

View File

@ -30,7 +30,7 @@ def from_opendatasoft_api(
break break
for row in records: for row in records:
yield { ** row.get('fields', {}), 'geometry': row.get('geometry', {})} yield {**row.get('fields', {}), 'geometry': row.get('geometry', {})}
start += rows start += rows

View File

@ -1,18 +1,12 @@
import csv import csv
from copy import copy
from bonobo import Option, ContextProcessor, contextual
from bonobo.util.objects import ValueHolder
from .file import FileReader, FileWriter, FileHandler from .file import FileReader, FileWriter, FileHandler
class CsvHandler(FileHandler): class CsvHandler(FileHandler):
delimiter = ';'
quotechar = '"'
headers = None
class CsvReader(CsvHandler, FileReader):
""" """
Reads a CSV and yield the values as dicts.
.. attribute:: delimiter .. attribute:: delimiter
@ -26,30 +20,33 @@ class CsvReader(CsvHandler, FileReader):
The list of column names, if the CSV does not contain it as its first line. The list of column names, if the CSV does not contain it as its first line.
"""
delimiter = Option(str, default=';')
quotechar = Option(str, default='"')
headers = Option(tuple)
@contextual
class CsvReader(CsvHandler, FileReader):
"""
Reads a CSV and yield the values as dicts.
.. attribute:: skip .. attribute:: skip
The amount of lines to skip before it actually yield output. The amount of lines to skip before it actually yield output.
""" """
skip = 0 skip = Option(int, default=0)
def __init__(self, path_or_buf, delimiter=None, quotechar=None, headers=None, skip=None): @ContextProcessor
super().__init__(path_or_buf) def csv_headers(self, context, file):
yield ValueHolder(self.headers)
self.delimiter = str(delimiter or self.delimiter) def read(self, file, headers):
self.quotechar = quotechar or self.quotechar reader = csv.reader(file, delimiter=self.delimiter, quotechar=self.quotechar)
self.headers = headers or self.headers headers.value = headers.value or next(reader)
self.skip = skip or self.skip field_count = len(headers.value)
@property
def has_headers(self):
return bool(self.headers)
def read(self, ctx):
reader = csv.reader(ctx.file, delimiter=self.delimiter, quotechar=self.quotechar)
headers = self.has_headers and self.headers or next(reader)
field_count = len(headers)
if self.skip and self.skip > 0: if self.skip and self.skip > 0:
for i in range(0, self.skip): for i in range(0, self.skip):
@ -62,30 +59,20 @@ class CsvReader(CsvHandler, FileReader):
field_count, field_count,
)) ))
yield dict(zip(headers, row)) yield dict(zip(headers.value, row))
@contextual
class CsvWriter(CsvHandler, FileWriter): class CsvWriter(CsvHandler, FileWriter):
def __init__(self, path_or_buf, delimiter=None, quotechar=None, headers=None): @ContextProcessor
super().__init__(path_or_buf) def writer(self, context, file, lineno):
writer = csv.writer(file, delimiter=self.delimiter, quotechar=self.quotechar)
headers = ValueHolder(list(self.headers) if self.headers else None)
yield writer, headers
self.delimiter = str(delimiter or self.delimiter) def write(self, file, lineno, writer, headers, row):
self.quotechar = quotechar or self.quotechar if not lineno.value:
self.headers = headers or self.headers headers.value = headers.value or row.keys()
writer.writerow(headers.value)
def initialize(self, ctx): writer.writerow(row[header] for header in headers.value)
super().initialize(ctx) lineno.value += 1
ctx.writer = csv.writer(ctx.file, delimiter=self.delimiter, quotechar=self.quotechar)
ctx.headers = copy(self.headers)
ctx.first = True
def write(self, ctx, row):
if ctx.first:
ctx.headers = ctx.headers or row.keys()
ctx.writer.writerow(ctx.headers)
ctx.first = False
ctx.writer.writerow(row[header] for header in ctx.headers)
def finalize(self, ctx):
del ctx.headers, ctx.writer, ctx.first
super().finalize(ctx)

View File

@ -1,63 +1,47 @@
from bonobo.util.lifecycle import with_context from bonobo.config import Configurable, Option
from bonobo.context import ContextProcessor
from bonobo.context.processors import contextual
from bonobo.util.objects import ValueHolder
__all__ = [ __all__ = [
'FileHandler',
'FileReader', 'FileReader',
'FileWriter', 'FileWriter',
] ]
@with_context @contextual
class FileHandler: class FileHandler(Configurable):
""" """
Abstract component factory for file-related components. Abstract component factory for file-related components.
""" """
eol = '\n' path = Option(str, required=True)
mode = None eol = Option(str, default='\n')
mode = Option(str)
def __init__(self, path_or_buf, eol=None): @ContextProcessor
self.path_or_buf = path_or_buf def file(self, context):
self.eol = eol or self.eol with self.open() as file:
yield file
def open(self): def open(self):
return open(self.path_or_buf, self.mode) return open(self.path, self.mode)
def close(self, fp):
"""
:param file fp:
"""
fp.close()
def initialize(self, ctx):
"""
Initialize a
:param ctx:
:return:
"""
assert not hasattr(ctx, 'file'), 'A file pointer is already in the context... I do not know what to say...'
ctx.file = self.open()
def finalize(self, ctx):
self.close(ctx.file)
del ctx.file
class Reader(FileHandler): class Reader(FileHandler):
def __call__(self, ctx): def __call__(self, *args):
yield from self.read(ctx) yield from self.read(*args)
def read(self, ctx): def read(self, *args):
raise NotImplementedError('Abstract.') raise NotImplementedError('Abstract.')
class Writer(FileHandler): class Writer(FileHandler):
def __call__(self, ctx, row): def __call__(self, *args):
return self.write(ctx, row) return self.write(*args)
def write(self, ctx, row): def write(self, *args):
raise NotImplementedError('Abstract.') raise NotImplementedError('Abstract.')
@ -70,20 +54,21 @@ class FileReader(Reader):
""" """
mode = 'r' mode = Option(str, default='r')
def read(self, ctx): def read(self, file):
""" """
Write a row on the next line of file pointed by `ctx.file`. Write a row on the next line of given file.
Prefix is used for newlines. Prefix is used for newlines.
:param ctx: :param ctx:
:param row: :param row:
""" """
for line in ctx.file: for line in file:
yield line.rstrip(self.eol) yield line.rstrip(self.eol)
@contextual
class FileWriter(Writer): class FileWriter(Writer):
""" """
Component factory for file or file-like writers. Component factory for file or file-like writers.
@ -93,13 +78,14 @@ class FileWriter(Writer):
""" """
mode = 'w+' mode = Option(str, default='w+')
def initialize(self, ctx): @ContextProcessor
ctx.line = 0 def lineno(self, context, file):
return super().initialize(ctx) lineno = ValueHolder(0, type=int)
yield lineno
def write(self, ctx, row): def write(self, file, lineno, row):
""" """
Write a row on the next line of opened file in context. Write a row on the next line of opened file in context.
@ -107,12 +93,8 @@ class FileWriter(Writer):
:param str row: :param str row:
:param str prefix: :param str prefix:
""" """
self._write_line(ctx.file, (self.eol if ctx.line else '') + row) self._write_line(file, (self.eol if lineno.value else '') + row)
ctx.line += 1 lineno.value += 1
def _write_line(self, fp, line): def _write_line(self, file, line):
return fp.write(line) return file.write(line)
def finalize(self, ctx):
del ctx.line
return super().finalize(ctx)

View File

@ -1,5 +1,6 @@
import json import json
from bonobo import ContextProcessor, contextual
from .file import FileWriter, FileReader from .file import FileWriter, FileReader
__all__ = ['JsonWriter', ] __all__ = ['JsonWriter', ]
@ -10,25 +11,24 @@ class JsonHandler:
class JsonReader(JsonHandler, FileReader): class JsonReader(JsonHandler, FileReader):
def read(self, ctx): def read(self, file):
for line in json.load(ctx.file): for line in json.load(file):
yield line yield line
@contextual
class JsonWriter(JsonHandler, FileWriter): class JsonWriter(JsonHandler, FileWriter):
def initialize(self, ctx): @ContextProcessor
super().initialize(ctx) def envelope(self, context, file, lineno):
ctx.file.write('[\n') file.write('[\n')
yield
file.write('\n]')
def write(self, ctx, row): def write(self, file, lineno, row):
""" """
Write a json row on the next line of file pointed by ctx.file. Write a json row on the next line of file pointed by ctx.file.
:param ctx: :param ctx:
:param row: :param row:
""" """
return super().write(ctx, json.dumps(row)) return super().write(file, lineno, json.dumps(row))
def finalize(self, ctx):
ctx.file.write('\n]')
super().finalize(ctx)

View File

@ -5,13 +5,10 @@ from pprint import pprint as _pprint
import blessings import blessings
from .helpers import run, console_run, jupyter_run from .helpers import console_run, jupyter_run
from .tokens import NOT_MODIFIED from .tokens import NOT_MODIFIED
from .options import Configurable, Option
__all__ = [ __all__ = [
'Configurable',
'Option',
'NOT_MODIFIED', 'NOT_MODIFIED',
'console_run', 'console_run',
'jupyter_run', 'jupyter_run',
@ -19,7 +16,6 @@ __all__ = [
'log', 'log',
'noop', 'noop',
'pprint', 'pprint',
'run',
'tee', 'tee',
] ]

View File

@ -1,25 +1,12 @@
def run(*chain, plugins=None, strategy=None):
from bonobo import Graph, ThreadPoolExecutorStrategy
if len(chain) == 1 and isinstance(chain[0], Graph):
graph = chain[0]
elif len(chain) >= 1:
graph = Graph()
graph.add_chain(*chain)
else:
raise RuntimeError('Empty chain.')
executor = (strategy or ThreadPoolExecutorStrategy)()
return executor.execute(graph, plugins=plugins or [])
def console_run(*chain, output=True, plugins=None, strategy=None): def console_run(*chain, output=True, plugins=None, strategy=None):
from bonobo import run
from bonobo.ext.console import ConsoleOutputPlugin from bonobo.ext.console import ConsoleOutputPlugin
return run(*chain, plugins=(plugins or []) + [ConsoleOutputPlugin()] if output else [], strategy=strategy) return run(*chain, plugins=(plugins or []) + [ConsoleOutputPlugin()] if output else [], strategy=strategy)
def jupyter_run(*chain, plugins=None, strategy=None): def jupyter_run(*chain, plugins=None, strategy=None):
from bonobo import run
from bonobo.ext.jupyter import JupyterOutputPlugin from bonobo.ext.jupyter import JupyterOutputPlugin
return run(*chain, plugins=(plugins or []) + [JupyterOutputPlugin()], strategy=strategy) return run(*chain, plugins=(plugins or []) + [JupyterOutputPlugin()], strategy=strategy)

22
bonobo/util/objects.py Normal file
View File

@ -0,0 +1,22 @@
def get_name(mixed):
    """
    Return a display name for any object.

    The object's own ``__name__`` is used when it has one; otherwise the name of its type is used instead.

    :param mixed: any object
    :return: ``mixed.__name__`` if that attribute exists, else ``type(mixed).__name__``
    """
    try:
        return mixed.__name__
    except AttributeError:
        # No ``__name__`` on the object itself: fall back to its type's name.
        return type(mixed).__name__
class Wrapper:
    """
    Hold a reference to another object and expose a name for it.

    :param wrapped: the object being wrapped, stored as ``self.wrapped``
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped

    @property
    def __name__(self):
        """
        Name of the wrapped object.

        Lookup order: ``wrapped.__name__``, then ``type(wrapped).__name__``, then ``repr(wrapped)``.
        """
        return getattr(self.wrapped, '__name__', getattr(type(self.wrapped), '__name__', repr(self.wrapped)))

    # ``name`` is bound to the very same property object as ``__name__`` above.
    name = __name__
class ValueHolder:
    """
    Mutable box around a single value.

    The held value is read and replaced through the ``value`` attribute, so the same holder object can be shared
    while its content changes.

    :param value: initial value, stored as ``self.value``
    :param type: optional keyword-only type hint, stored as-is in ``self.type`` (it is not enforced here)
    """

    def __init__(self, value, *, type=None):
        self.value = value
        self.type = type

    def __repr__(self):
        # ``type`` is only shadowed inside ``__init__``; here it is the builtin.
        return '{}({!r}, type={!r})'.format(type(self).__name__, self.value, self.type)

View File

@ -1,9 +1,9 @@
from unittest.mock import MagicMock from unittest.mock import MagicMock
from bonobo.core.contexts import ComponentExecutionContext from bonobo.context.execution import NodeExecutionContext
class CapturingComponentExecutionContext(ComponentExecutionContext): class CapturingNodeExecutionContext(NodeExecutionContext):
def __init__(self, component, parent): def __init__(self, wrapped, parent):
super().__init__(component, parent) super().__init__(wrapped, parent)
self.send = MagicMock() self.send = MagicMock()

View File

@ -11,8 +11,4 @@ class Token:
BEGIN = Token('Begin') BEGIN = Token('Begin')
END = Token('End') END = Token('End')
NEW = Token('New')
RUNNING = Token('Running')
TERMINATED = Token('Terminated')
NOT_MODIFIED = Token('NotModified') NOT_MODIFIED = Token('NotModified')

View File

@ -3,8 +3,6 @@ from random import randint
from bonobo import Bag from bonobo import Bag
from bonobo.core.graphs import Graph from bonobo.core.graphs import Graph
from bonobo.core.strategies.executor import ThreadPoolExecutorStrategy
from bonobo.ext.console import ConsoleOutputPlugin
def extract(): def extract():
@ -23,11 +21,10 @@ def load(topic: str, title: str, wait: int):
print('{} ({}) wait={}'.format(title, topic, wait)) print('{} ({}) wait={}'.format(title, topic, wait))
Strategy = ThreadPoolExecutorStrategy graph = Graph()
graph.add_chain(extract, transform, load)
if __name__ == '__main__': if __name__ == '__main__':
etl = Graph() from bonobo.util.helpers import run
etl.add_chain(extract, transform, load)
s = Strategy() run(graph)
s.execute(etl, plugins=[ConsoleOutputPlugin()])

View File

@ -2,8 +2,6 @@ import time
from random import randint from random import randint
from bonobo.core.graphs import Graph from bonobo.core.graphs import Graph
from bonobo.core.strategies.executor import ThreadPoolExecutorStrategy
from bonobo.ext.console import ConsoleOutputPlugin
def extract(): def extract():
@ -25,11 +23,10 @@ def load(s):
print(s) print(s)
Strategy = ThreadPoolExecutorStrategy graph = Graph()
graph.add_chain(extract, transform, load)
if __name__ == '__main__': if __name__ == '__main__':
etl = Graph() from bonobo import run
etl.add_chain(extract, transform, load)
s = Strategy() run(graph)
s.execute(etl, plugins=[ConsoleOutputPlugin()])

20
examples/basics_file.py Normal file
View File

@ -0,0 +1,20 @@
from bonobo import FileReader, Graph
def skip_comments(line):
    """
    Let a line through unless it is a comment.

    :param str line: a line of text
    :return: generator yielding ``line`` once, or nothing at all when ``line`` starts with ``#``
    """
    if not line.startswith('#'):
        yield line
# Pipeline: read /etc/passwd, drop comment lines, split each remaining line on ':' and print its first field.
graph = Graph(
    FileReader(path='/etc/passwd'),
    skip_comments,
    lambda s: s.split(':'),
    lambda fields: fields[0],
    print,
)

if __name__ == '__main__':
    # Imported here so that merely importing this example module does not run anything.
    import bonobo

    bonobo.run(graph)

View File

@ -0,0 +1,21 @@
import os
from bonobo import CsvReader, Graph
__path__ = os.path.dirname(__file__)
def skip_comments(line):
    """
    Let a line through unless it is a comment.

    :param str line: a line of text
    :return: generator yielding ``line`` once, or nothing at all when ``line`` starts with ``#``
    """
    if not line.startswith('#'):
        yield line
# Pipeline: read datasets/coffeeshops.txt (resolved relative to this file's directory) with CsvReader and
# print whatever the reader emits.
graph = Graph(
    CsvReader(path=os.path.join(__path__, 'datasets/coffeeshops.txt')),
    print,
)

if __name__ == '__main__':
    # Imported here so that merely importing this example module does not run anything.
    import bonobo

    bonobo.run(graph)

View File

@ -2,8 +2,6 @@ import time
from random import randint from random import randint
from bonobo.core.graphs import Graph from bonobo.core.graphs import Graph
from bonobo.core.strategies.executor import ThreadPoolExecutorStrategy
from bonobo.ext.console import ConsoleOutputPlugin
def extract(): def extract():
@ -22,11 +20,10 @@ def load(s):
print(s) print(s)
Strategy = ThreadPoolExecutorStrategy graph = Graph()
graph.add_chain(extract, transform, load)
if __name__ == '__main__': if __name__ == '__main__':
etl = Graph() from bonobo import run
etl.add_chain(extract, transform, load)
s = Strategy() run(graph)
s.execute(etl, plugins=[ConsoleOutputPlugin()])

View File

@ -1,8 +1,9 @@
import json import json
import os
from blessings import Terminal from blessings import Terminal
from bonobo import console_run, tee, JsonWriter, Graph from bonobo import tee, JsonWriter, Graph
from bonobo.ext.opendatasoft import from_opendatasoft_api from bonobo.ext.opendatasoft import from_opendatasoft_api
try: try:
@ -15,6 +16,7 @@ API_NETLOC = 'datanova.laposte.fr'
ROWS = 100 ROWS = 100
t = Terminal() t = Terminal()
__path__ = os.path.dirname(__file__)
def _getlink(x): def _getlink(x):
@ -57,13 +59,17 @@ def display(row):
graph = Graph( graph = Graph(
from_opendatasoft_api( from_opendatasoft_api(
API_DATASET, netloc=API_NETLOC, timezone='Europe/Paris' API_DATASET,
netloc=API_NETLOC,
timezone='Europe/Paris'
), ),
normalize, normalize,
filter_france, filter_france,
tee(display), tee(display),
JsonWriter('fablabs.json'), JsonWriter(path=os.path.join(__path__, 'datasets/coffeeshops.txt')),
) )
if __name__ == '__main__': if __name__ == '__main__':
console_run(graph, output=True) import bonobo
bonobo.run(graph)

View File

@ -0,0 +1,55 @@
from operator import attrgetter
from bonobo import contextual, ContextProcessor
from bonobo.context.processors import get_context_processors
@contextual
class CP1:
    # Fixture: processors are declared in the order c, a, b (deliberately not alphabetical); the ordering test
    # in this module expects exactly that order back.

    @ContextProcessor
    def c(self):
        pass

    @ContextProcessor
    def a(self):
        pass

    @ContextProcessor
    def b(self):
        pass
@contextual
class CP2(CP1):
    # Fixture: subclass of CP1 that declares three more processors, in the order f, e, d.

    @ContextProcessor
    def f(self):
        pass

    @ContextProcessor
    def e(self):
        pass

    @ContextProcessor
    def d(self):
        pass
@contextual
class CP3(CP2):
    # Fixture: subclass of CP2 that declares processors named c and b, names that CP1 also uses.

    @ContextProcessor
    def c(self):
        pass

    @ContextProcessor
    def b(self):
        pass
def get_all_processors_names(cls):
    """
    List the names of the context processors attached to ``cls``.

    :param cls: object handed to ``get_context_processors``
    :return: list of the ``__name__`` of each processor, in the order ``get_context_processors`` returns them
    """
    return list(map(attrgetter('__name__'), get_context_processors(cls)))
def test_inheritance_and_ordering():
    """
    Processors are listed parent class first, each class in declaration order; names redefined in a subclass
    (``c`` and ``b`` in CP3) are listed again after the inherited ones rather than replacing them.
    """
    assert get_all_processors_names(CP1) == ['c', 'a', 'b']
    assert get_all_processors_names(CP2) == ['c', 'a', 'b', 'f', 'e', 'd']
    assert get_all_processors_names(CP3) == ['c', 'a', 'b', 'f', 'e', 'd', 'c', 'b']

View File

@ -1,6 +1,5 @@
from bonobo import Graph, NaiveStrategy, Bag from bonobo import Graph, NaiveStrategy, Bag, contextual
from bonobo.core.contexts import ExecutionContext from bonobo.context.execution import GraphExecutionContext
from bonobo.util.lifecycle import with_context
from bonobo.util.tokens import BEGIN, END from bonobo.util.tokens import BEGIN, END
@ -12,11 +11,16 @@ def square(i: int) -> int:
return i**2 return i**2
@with_context @contextual
def push_result(ctx, i: int): def push_result(results, i: int):
if not hasattr(ctx.parent, 'results'): results.append(i)
ctx.parent.results = []
ctx.parent.results.append(i)
@push_result.__processors__.append
def results(f, context):
results = []
yield results
context.parent.results = results
chain = (generate_integers, square, push_result) chain = (generate_integers, square, push_result)
@ -25,8 +29,8 @@ chain = (generate_integers, square, push_result)
def test_empty_execution_context(): def test_empty_execution_context():
graph = Graph() graph = Graph()
ctx = ExecutionContext(graph) ctx = GraphExecutionContext(graph)
assert not len(ctx.components) assert not len(ctx.nodes)
assert not len(ctx.plugins) assert not len(ctx.plugins)
assert not ctx.alive assert not ctx.alive
@ -46,15 +50,19 @@ def test_simple_execution_context():
graph = Graph() graph = Graph()
graph.add_chain(*chain) graph.add_chain(*chain)
ctx = ExecutionContext(graph) ctx = GraphExecutionContext(graph)
assert len(ctx.components) == len(chain) assert len(ctx.nodes) == len(chain)
assert not len(ctx.plugins) assert not len(ctx.plugins)
for i, component in enumerate(chain): for i, node in enumerate(chain):
assert ctx[i].component is component assert ctx[i].wrapped is node
assert not ctx.alive assert not ctx.alive
ctx.recv(BEGIN, Bag(), END) ctx.recv(BEGIN, Bag(), END)
assert not ctx.alive
ctx.start()
assert ctx.alive assert ctx.alive

View File

@ -24,20 +24,20 @@ def test_graph_outputs_of():
def test_graph_add_component(): def test_graph_add_component():
g = Graph() g = Graph()
assert len(g.components) == 0 assert len(g.nodes) == 0
g.add_component(identity) g.add_node(identity)
assert len(g.components) == 1 assert len(g.nodes) == 1
g.add_component(identity) g.add_node(identity)
assert len(g.components) == 2 assert len(g.nodes) == 2
def test_graph_add_chain(): def test_graph_add_chain():
g = Graph() g = Graph()
assert len(g.components) == 0 assert len(g.nodes) == 0
g.add_chain(identity, identity, identity) g.add_chain(identity, identity, identity)
assert len(g.components) == 3 assert len(g.nodes) == 3
assert len(g.outputs_of(BEGIN)) == 1 assert len(g.outputs_of(BEGIN)) == 1

View File

@ -1,8 +1,8 @@
from bonobo.core.stats import WithStatistics from bonobo.core.statistics import WithStatistics
class MyThingWithStats(WithStatistics): class MyThingWithStats(WithStatistics):
def get_stats(self, *args, **kwargs): def get_statistics(self, *args, **kwargs):
return ( return (
('foo', 42), ('foo', 42),
('bar', 69), ('bar', 69),
@ -11,4 +11,4 @@ class MyThingWithStats(WithStatistics):
def test_with_statistics(): def test_with_statistics():
o = MyThingWithStats() o = MyThingWithStats()
assert o.get_stats_as_string() == 'foo=42 bar=69' assert o.get_statistics_as_string() == 'foo=42 bar=69'

View File

@ -1,21 +1,22 @@
import pytest import pytest
from bonobo import Bag, CsvReader, CsvWriter from bonobo import Bag, CsvReader, CsvWriter
from bonobo.core.contexts import ComponentExecutionContext from bonobo.context.execution import NodeExecutionContext
from bonobo.util.testing import CapturingComponentExecutionContext from bonobo.util.testing import CapturingNodeExecutionContext
from bonobo.util.tokens import BEGIN, END from bonobo.util.tokens import BEGIN, END
def test_write_csv_to_file(tmpdir): def test_write_csv_to_file(tmpdir):
file = tmpdir.join('output.json') file = tmpdir.join('output.json')
writer = CsvWriter(str(file)) writer = CsvWriter(path=str(file))
context = ComponentExecutionContext(writer, None) context = NodeExecutionContext(writer, None)
context.initialize()
context.recv(BEGIN, Bag({'foo': 'bar'}), Bag({'foo': 'baz', 'ignore': 'this'}), END) context.recv(BEGIN, Bag({'foo': 'bar'}), Bag({'foo': 'baz', 'ignore': 'this'}), END)
context.start()
context.step() context.step()
context.step() context.step()
context.finalize() context.stop()
assert file.read() == 'foo\nbar\nbaz\n' assert file.read() == 'foo\nbar\nbaz\n'
@ -23,27 +24,18 @@ def test_write_csv_to_file(tmpdir):
getattr(context, 'file') getattr(context, 'file')
def test_write_json_without_initializer_should_not_work(tmpdir):
file = tmpdir.join('output.json')
writer = CsvWriter(str(file))
context = ComponentExecutionContext(writer, None)
with pytest.raises(AttributeError):
writer(context, {'foo': 'bar'})
def test_read_csv_from_file(tmpdir): def test_read_csv_from_file(tmpdir):
file = tmpdir.join('input.csv') file = tmpdir.join('input.csv')
file.write('a,b,c\na foo,b foo,c foo\na bar,b bar,c bar') file.write('a,b,c\na foo,b foo,c foo\na bar,b bar,c bar')
reader = CsvReader(str(file), delimiter=',') reader = CsvReader(path=str(file), delimiter=',')
context = CapturingComponentExecutionContext(reader, None) context = CapturingNodeExecutionContext(reader, None)
context.initialize() context.start()
context.recv(BEGIN, Bag(), END) context.recv(BEGIN, Bag(), END)
context.step() context.step()
context.finalize() context.stop()
assert len(context.send.mock_calls) == 2 assert len(context.send.mock_calls) == 2

View File

@ -1,8 +1,8 @@
import pytest import pytest
from bonobo import FileWriter, Bag, FileReader from bonobo import FileWriter, Bag, FileReader
from bonobo.core.contexts import ComponentExecutionContext from bonobo.context.execution import NodeExecutionContext
from bonobo.util.testing import CapturingComponentExecutionContext from bonobo.util.testing import CapturingNodeExecutionContext
from bonobo.util.tokens import BEGIN, END from bonobo.util.tokens import BEGIN, END
@ -16,27 +16,24 @@ from bonobo.util.tokens import BEGIN, END
def test_file_writer_in_context(tmpdir, lines, output): def test_file_writer_in_context(tmpdir, lines, output):
file = tmpdir.join('output.txt') file = tmpdir.join('output.txt')
writer = FileWriter(str(file)) writer = FileWriter(path=str(file))
context = ComponentExecutionContext(writer, None) context = NodeExecutionContext(writer, None)
context.initialize() context.start()
context.recv(BEGIN, *map(Bag, lines), END) context.recv(BEGIN, *map(Bag, lines), END)
for i in range(len(lines)): for i in range(len(lines)):
context.step() context.step()
context.finalize() context.stop()
assert file.read() == output assert file.read() == output
with pytest.raises(AttributeError):
getattr(context, 'file')
def test_file_writer_out_of_context(tmpdir): def test_file_writer_out_of_context(tmpdir):
file = tmpdir.join('output.txt') file = tmpdir.join('output.txt')
writer = FileWriter(str(file)) writer = FileWriter(path=str(file))
fp = writer.open()
fp.write('Yosh!') with writer.open() as fp:
writer.close(fp) fp.write('Yosh!')
assert file.read() == 'Yosh!' assert file.read() == 'Yosh!'
@ -45,13 +42,13 @@ def test_file_reader_in_context(tmpdir):
file = tmpdir.join('input.txt') file = tmpdir.join('input.txt')
file.write('Hello\nWorld\n') file.write('Hello\nWorld\n')
reader = FileReader(str(file)) reader = FileReader(path=str(file))
context = CapturingComponentExecutionContext(reader, None) context = CapturingNodeExecutionContext(reader, None)
context.initialize() context.start()
context.recv(BEGIN, Bag(), END) context.recv(BEGIN, Bag(), END)
context.step() context.step()
context.finalize() context.stop()
assert len(context.send.mock_calls) == 2 assert len(context.send.mock_calls) == 2

View File

@ -1,20 +1,21 @@
import pytest import pytest
from bonobo import Bag, JsonWriter, JsonReader from bonobo import Bag, JsonWriter, JsonReader
from bonobo.core.contexts import ComponentExecutionContext from bonobo.context.execution import NodeExecutionContext
from bonobo.util.testing import CapturingComponentExecutionContext from bonobo.util.objects import ValueHolder
from bonobo.util.testing import CapturingNodeExecutionContext
from bonobo.util.tokens import BEGIN, END from bonobo.util.tokens import BEGIN, END
def test_write_json_to_file(tmpdir): def test_write_json_to_file(tmpdir):
file = tmpdir.join('output.json') file = tmpdir.join('output.json')
writer = JsonWriter(str(file)) writer = JsonWriter(path=str(file))
context = ComponentExecutionContext(writer, None) context = NodeExecutionContext(writer, None)
context.initialize() context.start()
context.recv(BEGIN, Bag({'foo': 'bar'}), END) context.recv(BEGIN, Bag({'foo': 'bar'}), END)
context.step() context.step()
context.finalize() context.stop()
assert file.read() == '[\n{"foo": "bar"}\n]' assert file.read() == '[\n{"foo": "bar"}\n]'
@ -25,26 +26,17 @@ def test_write_json_to_file(tmpdir):
getattr(context, 'first') getattr(context, 'first')
def test_write_json_without_initializer_should_not_work(tmpdir):
file = tmpdir.join('output.json')
writer = JsonWriter(str(file))
context = ComponentExecutionContext(writer, None)
with pytest.raises(AttributeError):
writer(context, {'foo': 'bar'})
def test_read_json_from_file(tmpdir): def test_read_json_from_file(tmpdir):
file = tmpdir.join('input.json') file = tmpdir.join('input.json')
file.write('[{"x": "foo"},{"x": "bar"}]') file.write('[{"x": "foo"},{"x": "bar"}]')
reader = JsonReader(str(file)) reader = JsonReader(path=str(file))
context = CapturingComponentExecutionContext(reader, None) context = CapturingNodeExecutionContext(reader, None)
context.initialize() context.start()
context.recv(BEGIN, Bag(), END) context.recv(BEGIN, Bag(), END)
context.step() context.step()
context.finalize() context.stop()
assert len(context.send.mock_calls) == 2 assert len(context.send.mock_calls) == 2