diff --git a/bonobo/commands/convert.py b/bonobo/commands/convert.py index 2d13ab4..48caaa3 100644 --- a/bonobo/commands/convert.py +++ b/bonobo/commands/convert.py @@ -120,7 +120,8 @@ def register(parser): parser.add_argument( '--' + WRITER, '-w', - help='Choose the writer factory if it cannot be detected from extension, or if detection is wrong (use - for console pretty print).' + help= + 'Choose the writer factory if it cannot be detected from extension, or if detection is wrong (use - for console pretty print).' ) parser.add_argument( '--filter', diff --git a/bonobo/config/options.py b/bonobo/config/options.py index cad4ca8..3462a31 100644 --- a/bonobo/config/options.py +++ b/bonobo/config/options.py @@ -66,10 +66,10 @@ class Option: self._creation_counter = Option._creation_counter Option._creation_counter += 1 - def __get__(self, inst, typ): + def __get__(self, inst, type_): # XXX If we call this on the type, then either return overriden value or ... ??? if inst is None: - return vars(type).get(self.name, self) + return vars(type_).get(self.name, self) if not self.name in inst._options_values: inst._options_values[self.name] = self.get_default() @@ -96,6 +96,24 @@ class Option: return self.default() if callable(self.default) else self.default +class RemovedOption(Option): + def __init__(self, *args, value, **kwargs): + kwargs['required'] = False + super(RemovedOption, self).__init__(*args, **kwargs) + self.value = value + + def clean(self, value): + if value != self.value: + raise ValueError( + 'Removed options cannot change value, {!r} must now be {!r} (and you should remove setting the value explicitely, as it is deprecated and will be removed quite soon.'. + format(self.name, self.value) + ) + return self.value + + def get_default(self): + return self.value + + class Method(Option): """ A Method is a special callable-valued option, that can be used in three different ways (but for same purpose). diff --git a/bonobo/config/processors.py b/bonobo/config/processors.py index 27f8703..73c9949 100644 --- a/bonobo/config/processors.py +++ b/bonobo/config/processors.py @@ -1,9 +1,8 @@ from collections import Iterable from contextlib import contextmanager -from bonobo.config.options import Option -from bonobo.util.compat import deprecated_alias -from bonobo.util.iterators import ensure_tuple +from bonobo.config import Option +from bonobo.util import deprecated_alias, ensure_tuple _CONTEXT_PROCESSORS_ATTR = '__processors__' @@ -24,7 +23,7 @@ class ContextProcessor(Option): Example: >>> from bonobo.config import Configurable - >>> from bonobo.util.objects import ValueHolder + >>> from bonobo.util import ValueHolder >>> class Counter(Configurable): ... @ContextProcessor @@ -91,7 +90,10 @@ class ContextCurrifier: self._stack, self._stack_values = list(), list() for processor in resolve_processors(self.wrapped): _processed = processor(self.wrapped, *context, *self.context) - _append_to_context = next(_processed) + try: + _append_to_context = next(_processed) + except TypeError as exc: + raise TypeError('Context processor should be generators (using yield).') from exc self._stack_values.append(_append_to_context) if _append_to_context is not None: self.context += ensure_tuple(_append_to_context) diff --git a/bonobo/execution/base.py b/bonobo/execution/base.py index abb3516..81ac74e 100644 --- a/bonobo/execution/base.py +++ b/bonobo/execution/base.py @@ -5,6 +5,7 @@ from time import sleep from bonobo.config import create_container from bonobo.config.processors import ContextCurrifier from bonobo.plugins import get_enhancers +from bonobo.util import inspect_node, isconfigurabletype from bonobo.util.errors import print_error from bonobo.util.objects import Wrapper, get_name @@ -72,6 +73,15 @@ class LoopingExecutionContext(Wrapper): self._started = True self._stack = ContextCurrifier(self.wrapped, *self._get_initial_context()) + if isconfigurabletype(self.wrapped): + # Not normal to have a partially configured object here, so let's warn the user instead of having get into + # the hard trouble of understanding that by himself. + raise TypeError( + 'The Configurable should be fully instanciated by now, unfortunately I got a PartiallyConfigured object...' + ) + # XXX enhance that, maybe giving hints on what's missing. + # print(inspect_node(self.wrapped)) + self._stack.setup(self) for enhancer in self._enhancers: diff --git a/bonobo/execution/graph.py b/bonobo/execution/graph.py index 91e4aef..77e01fa 100644 --- a/bonobo/execution/graph.py +++ b/bonobo/execution/graph.py @@ -1,3 +1,4 @@ +import time from functools import partial from bonobo.config import create_container @@ -7,6 +8,9 @@ from bonobo.execution.plugin import PluginExecutionContext class GraphExecutionContext: + NodeExecutionContextType = NodeExecutionContext + PluginExecutionContextType = PluginExecutionContext + @property def started(self): return any(node.started for node in self.nodes) @@ -21,15 +25,17 @@ class GraphExecutionContext: def __init__(self, graph, plugins=None, services=None): self.graph = graph - self.nodes = [NodeExecutionContext(node, parent=self) for node in self.graph] - self.plugins = [PluginExecutionContext(plugin, parent=self) for plugin in plugins or ()] + self.nodes = [self.create_node_execution_context_for(node) for node in self.graph] + self.plugins = [self.create_plugin_execution_context_for(plugin) for plugin in plugins or ()] self.services = create_container(services) # Probably not a good idea to use it unless you really know what you're doing. But you can access the context. self.services['__graph_context'] = self for i, node_context in enumerate(self): - node_context.outputs = [self[j].input for j in self.graph.outputs_of(i)] + outputs = self.graph.outputs_of(i) + if len(outputs): + node_context.outputs = [self[j].input for j in outputs] node_context.input.on_begin = partial(node_context.send, BEGIN, _control=True) node_context.input.on_end = partial(node_context.send, END, _control=True) node_context.input.on_finalize = partial(node_context.stop) @@ -43,6 +49,12 @@ class GraphExecutionContext: def __iter__(self): yield from self.nodes + def create_node_execution_context_for(self, node): + return self.NodeExecutionContextType(node, parent=self) + + def create_plugin_execution_context_for(self, plugin): + return self.PluginExecutionContextType(plugin, parent=self) + def write(self, *messages): """Push a list of messages in the inputs of this graph's inputs, matching the output of special node "BEGIN" in our graph.""" @@ -51,17 +63,23 @@ class GraphExecutionContext: for message in messages: self[i].write(message) - def start(self): - # todo use strategy + def start(self, starter=None): for node in self.nodes: - node.start() + if starter is None: + node.start() + else: + starter(node) - def stop(self): - # todo use strategy - for node in self.nodes: - node.stop() + def start_plugins(self, starter=None): + for plugin in self.plugins: + if starter is None: + plugin.start() + else: + starter(plugin) - def loop(self): - # todo use strategy + def stop(self, stopper=None): for node in self.nodes: - node.loop() + if stopper is None: + node.stop() + else: + stopper(node) diff --git a/bonobo/execution/node.py b/bonobo/execution/node.py index e8869ac..2aa626c 100644 --- a/bonobo/execution/node.py +++ b/bonobo/execution/node.py @@ -2,15 +2,15 @@ import traceback from queue import Empty from time import sleep -from bonobo.constants import INHERIT_INPUT, NOT_MODIFIED +from bonobo import settings +from bonobo.constants import INHERIT_INPUT, NOT_MODIFIED, BEGIN, END from bonobo.errors import InactiveReadableError, UnrecoverableError from bonobo.execution.base import LoopingExecutionContext from bonobo.structs.bags import Bag from bonobo.structs.inputs import Input +from bonobo.util import get_name, iserrorbag, isloopbackbag, isdict, istuple from bonobo.util.compat import deprecated_alias -from bonobo.util.inspect import iserrorbag, isloopbackbag from bonobo.util.iterators import iter_if_not_sequence -from bonobo.util.objects import get_name from bonobo.util.statistics import WithStatistics @@ -28,12 +28,12 @@ class NodeExecutionContext(WithStatistics, LoopingExecutionContext): def alive_str(self): return '+' if self.alive else '-' - def __init__(self, wrapped, parent=None, services=None): + def __init__(self, wrapped, parent=None, services=None, _input=None, _outputs=None): LoopingExecutionContext.__init__(self, wrapped, parent=parent, services=services) WithStatistics.__init__(self, 'in', 'out', 'err') - self.input = Input() - self.outputs = [] + self.input = _input or Input() + self.outputs = _outputs or [] def __str__(self): return self.alive_str + ' ' + self.__name__ + self.get_statistics_as_string(prefix=' ') @@ -51,6 +51,11 @@ class NodeExecutionContext(WithStatistics, LoopingExecutionContext): for message in messages: self.input.put(message) + def write_sync(self, *messages): + self.write(BEGIN, *messages, END) + for _ in messages: + self.step() + # XXX deprecated alias recv = deprecated_alias('recv', write) @@ -143,12 +148,18 @@ def _resolve(input_bag, output): return output # If it does not look like a bag, let's create one for easier manipulation - if hasattr(output, 'apply'): + if hasattr(output, 'apply'): # XXX TODO use isbag() ? # Already a bag? Check if we need to set parent. if INHERIT_INPUT in output.flags: output.set_parent(input_bag) - else: - # Not a bag? Let's encapsulate it. - output = Bag(output) + return output - return output + # If we're using kwargs ioformat, then a dict means kwargs. + if settings.IOFORMAT == settings.IOFORMAT_KWARGS and isdict(output): + return Bag(**output) + + if istuple(output): + return Bag(*output) + + # Either we use arg0 format, either it's "just" a value. + return Bag(output) diff --git a/bonobo/nodes/io/base.py b/bonobo/nodes/io/base.py index 3cecb70..af9e609 100644 --- a/bonobo/nodes/io/base.py +++ b/bonobo/nodes/io/base.py @@ -1,39 +1,4 @@ -from bonobo import settings from bonobo.config import Configurable, ContextProcessor, Option, Service -from bonobo.errors import UnrecoverableValueError, UnrecoverableNotImplementedError -from bonobo.structs.bags import Bag - - -class IOFormatEnabled(Configurable): - ioformat = Option(default=settings.IOFORMAT.get) - - def get_input(self, *args, **kwargs): - if self.ioformat == settings.IOFORMAT_ARG0: - if len(args) != 1 or len(kwargs): - raise UnrecoverableValueError( - 'Wrong input formating: IOFORMAT=ARG0 implies one arg and no kwargs, got args={!r} and kwargs={!r}.'. - format(args, kwargs) - ) - return args[0] - - if self.ioformat == settings.IOFORMAT_KWARGS: - if len(args) or not len(kwargs): - raise UnrecoverableValueError( - 'Wrong input formating: IOFORMAT=KWARGS ioformat implies no arg, got args={!r} and kwargs={!r}.'. - format(args, kwargs) - ) - return kwargs - - raise UnrecoverableNotImplementedError('Unsupported format.') - - def get_output(self, row): - if self.ioformat == settings.IOFORMAT_ARG0: - return row - - if self.ioformat == settings.IOFORMAT_KWARGS: - return Bag(**row) - - raise UnrecoverableNotImplementedError('Unsupported format.') class FileHandler(Configurable): diff --git a/bonobo/nodes/io/csv.py b/bonobo/nodes/io/csv.py index 75fffe8..c504c16 100644 --- a/bonobo/nodes/io/csv.py +++ b/bonobo/nodes/io/csv.py @@ -1,10 +1,11 @@ import csv from bonobo.config import Option +from bonobo.config.options import RemovedOption from bonobo.config.processors import ContextProcessor from bonobo.constants import NOT_MODIFIED +from bonobo.nodes.io.base import FileHandler from bonobo.nodes.io.file import FileReader, FileWriter -from bonobo.nodes.io.base import FileHandler, IOFormatEnabled from bonobo.util.objects import ValueHolder @@ -27,9 +28,10 @@ class CsvHandler(FileHandler): delimiter = Option(str, default=';') quotechar = Option(str, default='"') headers = Option(tuple, required=False) + ioformat = RemovedOption(positional=False, value='kwargs') -class CsvReader(IOFormatEnabled, FileReader, CsvHandler): +class CsvReader(FileReader, CsvHandler): """ Reads a CSV and yield the values as dicts. @@ -62,18 +64,17 @@ class CsvReader(IOFormatEnabled, FileReader, CsvHandler): if len(row) != field_count: raise ValueError('Got a line with %d fields, expecting %d.' % (len(row), field_count, )) - yield self.get_output(dict(zip(_headers, row))) + yield dict(zip(_headers, row)) -class CsvWriter(IOFormatEnabled, FileWriter, CsvHandler): +class CsvWriter(FileWriter, CsvHandler): @ContextProcessor def writer(self, context, fs, file, lineno): writer = csv.writer(file, delimiter=self.delimiter, quotechar=self.quotechar, lineterminator=self.eol) headers = ValueHolder(list(self.headers) if self.headers else None) yield writer, headers - def write(self, fs, file, lineno, writer, headers, *args, **kwargs): - row = self.get_input(*args, **kwargs) + def write(self, fs, file, lineno, writer, headers, **row): if not lineno: headers.set(headers.value or row.keys()) writer.writerow(headers.get()) diff --git a/bonobo/nodes/io/json.py b/bonobo/nodes/io/json.py index f1c6df0..533d628 100644 --- a/bonobo/nodes/io/json.py +++ b/bonobo/nodes/io/json.py @@ -1,8 +1,9 @@ import json +from bonobo.config.options import RemovedOption from bonobo.config.processors import ContextProcessor from bonobo.constants import NOT_MODIFIED -from bonobo.nodes.io.base import FileHandler, IOFormatEnabled +from bonobo.nodes.io.base import FileHandler from bonobo.nodes.io.file import FileReader, FileWriter from bonobo.structs.bags import Bag @@ -10,14 +11,15 @@ from bonobo.structs.bags import Bag class JsonHandler(FileHandler): eol = ',\n' prefix, suffix = '[', ']' + ioformat = RemovedOption(positional=False, value='kwargs') -class JsonReader(IOFormatEnabled, FileReader, JsonHandler): +class JsonReader(FileReader, JsonHandler): loader = staticmethod(json.load) def read(self, fs, file): for line in self.loader(file): - yield self.get_output(line) + yield line class JsonDictItemsReader(JsonReader): @@ -26,21 +28,20 @@ class JsonDictItemsReader(JsonReader): yield Bag(*line) -class JsonWriter(IOFormatEnabled, FileWriter, JsonHandler): +class JsonWriter(FileWriter, JsonHandler): @ContextProcessor def envelope(self, context, fs, file, lineno): file.write(self.prefix) yield file.write(self.suffix) - def write(self, fs, file, lineno, *args, **kwargs): + def write(self, fs, file, lineno, **row): """ Write a json row on the next line of file pointed by ctx.file. :param ctx: :param row: """ - row = self.get_input(*args, **kwargs) self._write_line(file, (self.eol if lineno.value else '') + json.dumps(row)) lineno += 1 return NOT_MODIFIED diff --git a/bonobo/nodes/io/pickle.py b/bonobo/nodes/io/pickle.py index d9da55f..216c21b 100644 --- a/bonobo/nodes/io/pickle.py +++ b/bonobo/nodes/io/pickle.py @@ -1,9 +1,10 @@ import pickle from bonobo.config import Option +from bonobo.config.options import RemovedOption from bonobo.config.processors import ContextProcessor from bonobo.constants import NOT_MODIFIED -from bonobo.nodes.io.base import FileHandler, IOFormatEnabled +from bonobo.nodes.io.base import FileHandler from bonobo.nodes.io.file import FileReader, FileWriter from bonobo.util.objects import ValueHolder @@ -20,7 +21,7 @@ class PickleHandler(FileHandler): item_names = Option(tuple, required=False) -class PickleReader(IOFormatEnabled, FileReader, PickleHandler): +class PickleReader(FileReader, PickleHandler): """ Reads a Python pickle object and yields the items in dicts. """ @@ -54,10 +55,10 @@ class PickleReader(IOFormatEnabled, FileReader, PickleHandler): if len(i) != item_count: raise ValueError('Received an object with %d items, expecting %d.' % (len(i), item_count, )) - yield self.get_output(dict(zip(i)) if is_dict else dict(zip(pickle_headers.value, i))) + yield dict(zip(i)) if is_dict else dict(zip(pickle_headers.value, i)) -class PickleWriter(IOFormatEnabled, FileWriter, PickleHandler): +class PickleWriter(FileWriter, PickleHandler): mode = Option(str, default='wb') def write(self, fs, file, lineno, item): diff --git a/bonobo/settings.py b/bonobo/settings.py index e5edd83..ef4be2d 100644 --- a/bonobo/settings.py +++ b/bonobo/settings.py @@ -42,6 +42,9 @@ class Setting: def __repr__(self): return ''.format(self.name, self.get()) + def __eq__(self, other): + return self.get() == other + def set(self, value): value = self.formatter(value) if self.formatter else value if self.validator and not self.validator(value): diff --git a/bonobo/strategies/base.py b/bonobo/strategies/base.py index 4b345d4..47f7db4 100644 --- a/bonobo/strategies/base.py +++ b/bonobo/strategies/base.py @@ -6,10 +6,13 @@ class Strategy: Base class for execution strategies. """ - graph_execution_context_factory = GraphExecutionContext + GraphExecutionContextType = GraphExecutionContext - def create_graph_execution_context(self, graph, *args, **kwargs): - return self.graph_execution_context_factory(graph, *args, **kwargs) + def __init__(self, GraphExecutionContextType=None): + self.GraphExecutionContextType = GraphExecutionContextType or self.GraphExecutionContextType + + def create_graph_execution_context(self, graph, *args, GraphExecutionContextType=None, **kwargs): + return (GraphExecutionContextType or self.GraphExecutionContextType)(graph, *args, **kwargs) def execute(self, graph, *args, **kwargs): raise NotImplementedError diff --git a/bonobo/strategies/executor.py b/bonobo/strategies/executor.py index a0bd4f4..3bfabc6 100644 --- a/bonobo/strategies/executor.py +++ b/bonobo/strategies/executor.py @@ -19,42 +19,16 @@ class ExecutorStrategy(Strategy): def create_executor(self): return self.executor_factory() - def execute(self, graph, *args, plugins=None, services=None, **kwargs): - context = self.create_graph_execution_context(graph, plugins=plugins, services=services) + def execute(self, graph, **kwargs): + context = self.create_graph_execution_context(graph, **kwargs) context.write(BEGIN, Bag(), END) executor = self.create_executor() futures = [] - for plugin_context in context.plugins: - - def _runner(plugin_context=plugin_context): - with plugin_context: - try: - plugin_context.loop() - except Exception as exc: - print_error(exc, traceback.format_exc(), context=plugin_context) - - futures.append(executor.submit(_runner)) - - for node_context in context.nodes: - - def _runner(node_context=node_context): - try: - node_context.start() - except Exception as exc: - print_error(exc, traceback.format_exc(), context=node_context, method='start') - node_context.input.on_end() - else: - node_context.loop() - - try: - node_context.stop() - except Exception as exc: - print_error(exc, traceback.format_exc(), context=node_context, method='stop') - - futures.append(executor.submit(_runner)) + context.start_plugins(self.get_plugin_starter(executor, futures)) + context.start(self.get_starter(executor, futures)) while context.alive: time.sleep(0.1) @@ -62,10 +36,45 @@ class ExecutorStrategy(Strategy): for plugin_context in context.plugins: plugin_context.shutdown() + context.stop() + executor.shutdown() return context + def get_starter(self, executor, futures): + def starter(node): + def _runner(): + try: + node.start() + except Exception as exc: + print_error(exc, traceback.format_exc(), context=node, method='start') + node.input.on_end() + else: + node.loop() + + try: + node.stop() + except Exception as exc: + print_error(exc, traceback.format_exc(), context=node, method='stop') + + futures.append(executor.submit(_runner)) + + return starter + + def get_plugin_starter(self, executor, futures): + def plugin_starter(plugin): + def _runner(): + with plugin: + try: + plugin.loop() + except Exception as exc: + print_error(exc, traceback.format_exc(), context=plugin) + + futures.append(executor.submit(_runner)) + + return plugin_starter + class ThreadPoolExecutorStrategy(ExecutorStrategy): executor_factory = ThreadPoolExecutor diff --git a/bonobo/strategies/naive.py b/bonobo/strategies/naive.py index cab9c57..20477c1 100644 --- a/bonobo/strategies/naive.py +++ b/bonobo/strategies/naive.py @@ -4,13 +4,23 @@ from bonobo.structs.bags import Bag class NaiveStrategy(Strategy): - def execute(self, graph, *args, plugins=None, **kwargs): - context = self.create_graph_execution_context(graph, plugins=plugins) + # TODO: how to run plugins in "naive" mode ? + + def execute(self, graph, **kwargs): + context = self.create_graph_execution_context(graph, **kwargs) context.write(BEGIN, Bag(), END) - # TODO: how to run plugins in "naive" mode ? + # start context.start() - context.loop() + + # loop + nodes = list(context.nodes) + while len(nodes): + for node in nodes: + node.loop() + nodes = list(node for node in nodes if node.alive) + + # stop context.stop() return context diff --git a/bonobo/structs/bags.py b/bonobo/structs/bags.py index 3eae9ff..31bc870 100644 --- a/bonobo/structs/bags.py +++ b/bonobo/structs/bags.py @@ -96,7 +96,29 @@ class Bag: return cls(*args, _flags=(INHERIT_INPUT, ), **kwargs) def __eq__(self, other): - return isinstance(other, Bag) and other.args == self.args and other.kwargs == self.kwargs + # XXX there are overlapping cases, but this is very handy for now. Let's think about it later. + + # bag + if isinstance(other, Bag) and other.args == self.args and other.kwargs == self.kwargs: + return True + + # tuple of (tuple, dict) + if isinstance(other, tuple) and len(other) == 2 and other[0] == self.args and other[1] == self.kwargs: + return True + + # tuple (aka args) + if isinstance(other, tuple) and other == self.args: + return True + + # dict (aka kwargs) + if isinstance(other, dict) and not self.args and other == self.kwargs: + return True + + # arg0 + if len(self.args) == 1 and not len(self.kwargs) and self.args[0] == other: + return True + + return False def __repr__(self): return '<{} ({})>'.format( diff --git a/bonobo/util/__init__.py b/bonobo/util/__init__.py index df14e9a..e2eebe1 100644 --- a/bonobo/util/__init__.py +++ b/bonobo/util/__init__.py @@ -1,14 +1,18 @@ from bonobo.util.collections import sortedlist +from bonobo.util.iterators import ensure_tuple +from bonobo.util.compat import deprecated, deprecated_alias from bonobo.util.inspect import ( inspect_node, isbag, isconfigurable, isconfigurabletype, iscontextprocessor, + isdict, iserrorbag, isloopbackbag, ismethod, isoption, + istuple, istype, ) from bonobo.util.objects import (get_name, get_attribute_or_create, ValueHolder) @@ -17,6 +21,8 @@ from bonobo.util.python import require # Bonobo's util API __all__ = [ 'ValueHolder', + 'deprecated', + 'deprecated_alias', 'get_attribute_or_create', 'get_name', 'inspect_node', @@ -24,6 +30,7 @@ __all__ = [ 'isconfigurable', 'isconfigurabletype', 'iscontextprocessor', + 'isdict', 'iserrorbag', 'isloopbackbag', 'ismethod', diff --git a/bonobo/util/inspect.py b/bonobo/util/inspect.py index f9ae4d8..a3c71d7 100644 --- a/bonobo/util/inspect.py +++ b/bonobo/util/inspect.py @@ -68,6 +68,26 @@ def istype(mixed): return isinstance(mixed, type) +def isdict(mixed): + """ + Check if the given argument is a dict. + + :param mixed: + :return: bool + """ + return isinstance(mixed, dict) + + +def istuple(mixed): + """ + Check if the given argument is a tuple. + + :param mixed: + :return: bool + """ + return isinstance(mixed, tuple) + + def isbag(mixed): """ Check if the given argument is an instance of a :class:`bonobo.Bag`. diff --git a/bonobo/util/iterators.py b/bonobo/util/iterators.py index 82f8518..04c81a5 100644 --- a/bonobo/util/iterators.py +++ b/bonobo/util/iterators.py @@ -38,6 +38,6 @@ def tuplize(generator): def iter_if_not_sequence(mixed): - if isinstance(mixed, (dict, list, str)): + if isinstance(mixed, (dict, list, str, bytes, )): raise TypeError(type(mixed).__name__) return iter(mixed) diff --git a/bonobo/util/testing.py b/bonobo/util/testing.py index 7c07256..6fc7d60 100644 --- a/bonobo/util/testing.py +++ b/bonobo/util/testing.py @@ -1,16 +1,10 @@ from contextlib import contextmanager -from unittest.mock import MagicMock -from bonobo import open_fs +from bonobo import open_fs, Token +from bonobo.execution import GraphExecutionContext from bonobo.execution.node import NodeExecutionContext -class CapturingNodeExecutionContext(NodeExecutionContext): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.send = MagicMock() - - @contextmanager def optional_contextmanager(cm, *, ignore=False): if cm is None or ignore: @@ -35,3 +29,38 @@ class FilesystemTester: def get_services_for_writer(self, tmpdir): fs, filename = open_fs(tmpdir), 'output.' + self.extension return fs, filename, {'fs': fs} + + +class QueueList(list): + def append(self, item): + if not isinstance(item, Token): + super(QueueList, self).append(item) + + put = append + + +class BufferingContext: + def __init__(self, buffer=None): + if buffer is None: + buffer = QueueList() + self.buffer = buffer + + def get_buffer(self): + return self.buffer + + +class BufferingNodeExecutionContext(BufferingContext, NodeExecutionContext): + def __init__(self, *args, buffer=None, **kwargs): + BufferingContext.__init__(self, buffer) + NodeExecutionContext.__init__(self, *args, **kwargs, _outputs=[self.buffer]) + + +class BufferingGraphExecutionContext(BufferingContext, GraphExecutionContext): + NodeExecutionContextType = BufferingNodeExecutionContext + + def __init__(self, *args, buffer=None, **kwargs): + BufferingContext.__init__(self, buffer) + GraphExecutionContext.__init__(self, *args, **kwargs) + + def create_node_execution_context_for(self, node): + return self.NodeExecutionContextType(node, parent=self, buffer=self.buffer) diff --git a/tests/execution/test_node.py b/tests/execution/test_node.py new file mode 100644 index 0000000..23748d4 --- /dev/null +++ b/tests/execution/test_node.py @@ -0,0 +1,104 @@ +from bonobo import Bag, Graph +from bonobo.strategies import NaiveStrategy +from bonobo.util.testing import BufferingNodeExecutionContext, BufferingGraphExecutionContext + + +def test_node_string(): + def f(): + return 'foo' + + with BufferingNodeExecutionContext(f) as context: + context.write_sync(Bag()) + output = context.get_buffer() + + assert len(output) == 1 + assert output[0] == (('foo', ), {}) + + def g(): + yield 'foo' + yield 'bar' + + with BufferingNodeExecutionContext(g) as context: + context.write_sync(Bag()) + output = context.get_buffer() + + assert len(output) == 2 + assert output[0] == (('foo', ), {}) + assert output[1] == (('bar', ), {}) + + +def test_node_bytes(): + def f(): + return b'foo' + + with BufferingNodeExecutionContext(f) as context: + context.write_sync(Bag()) + + output = context.get_buffer() + assert len(output) == 1 + assert output[0] == ((b'foo', ), {}) + + def g(): + yield b'foo' + yield b'bar' + + with BufferingNodeExecutionContext(g) as context: + context.write_sync(Bag()) + output = context.get_buffer() + + assert len(output) == 2 + assert output[0] == ((b'foo', ), {}) + assert output[1] == ((b'bar', ), {}) + + +def test_node_dict(): + def f(): + return {'id': 1, 'name': 'foo'} + + with BufferingNodeExecutionContext(f) as context: + context.write_sync(Bag()) + output = context.get_buffer() + + assert len(output) == 1 + assert output[0] == {'id': 1, 'name': 'foo'} + + def g(): + yield {'id': 1, 'name': 'foo'} + yield {'id': 2, 'name': 'bar'} + + with BufferingNodeExecutionContext(g) as context: + context.write_sync(Bag()) + output = context.get_buffer() + + assert len(output) == 2 + assert output[0] == {'id': 1, 'name': 'foo'} + assert output[1] == {'id': 2, 'name': 'bar'} + + +def test_node_dict_chained(): + strategy = NaiveStrategy(GraphExecutionContextType=BufferingGraphExecutionContext) + + def uppercase_name(**kwargs): + return {**kwargs, 'name': kwargs['name'].upper()} + + def f(): + return {'id': 1, 'name': 'foo'} + + graph = Graph(f, uppercase_name) + context = strategy.execute(graph) + output = context.get_buffer() + + assert len(output) == 1 + assert output[0] == {'id': 1, 'name': 'FOO'} + + def g(): + yield {'id': 1, 'name': 'foo'} + yield {'id': 2, 'name': 'bar'} + + graph = Graph(g, uppercase_name) + context = strategy.execute(graph) + output = context.get_buffer() + + assert len(output) == 2 + assert output[0] == {'id': 1, 'name': 'FOO'} + assert output[1] == {'id': 2, 'name': 'BAR'} diff --git a/tests/io/test_csv.py b/tests/io/test_csv.py index 9a9480c..fc189ac 100644 --- a/tests/io/test_csv.py +++ b/tests/io/test_csv.py @@ -3,25 +3,19 @@ import pytest from bonobo import Bag, CsvReader, CsvWriter, settings from bonobo.constants import BEGIN, END from bonobo.execution.node import NodeExecutionContext -from bonobo.util.testing import CapturingNodeExecutionContext, FilesystemTester +from bonobo.util.testing import FilesystemTester, BufferingNodeExecutionContext csv_tester = FilesystemTester('csv') csv_tester.input_data = 'a,b,c\na foo,b foo,c foo\na bar,b bar,c bar' -def test_write_csv_to_file_arg0(tmpdir): +def test_write_csv_ioformat_arg0(tmpdir): fs, filename, services = csv_tester.get_services_for_writer(tmpdir) + with pytest.raises(ValueError): + CsvWriter(path=filename, ioformat=settings.IOFORMAT_ARG0) - with NodeExecutionContext(CsvWriter(path=filename, ioformat=settings.IOFORMAT_ARG0), services=services) as context: - context.write(BEGIN, Bag({'foo': 'bar'}), Bag({'foo': 'baz', 'ignore': 'this'}), END) - context.step() - context.step() - - with fs.open(filename) as fp: - assert fp.read() == 'foo\nbar\nbaz\n' - - with pytest.raises(AttributeError): - getattr(context, 'file') + with pytest.raises(ValueError): + CsvReader(path=filename, delimiter=',', ioformat=settings.IOFORMAT_ARG0), @pytest.mark.parametrize('add_kwargs', ({}, { @@ -30,7 +24,7 @@ def test_write_csv_to_file_arg0(tmpdir): def test_write_csv_to_file_kwargs(tmpdir, add_kwargs): fs, filename, services = csv_tester.get_services_for_writer(tmpdir) - with NodeExecutionContext(CsvWriter(path=filename, **add_kwargs), services=services) as context: + with NodeExecutionContext(CsvWriter(filename, **add_kwargs), services=services) as context: context.write(BEGIN, Bag(**{'foo': 'bar'}), Bag(**{'foo': 'baz', 'ignore': 'this'}), END) context.step() context.step() @@ -42,61 +36,24 @@ def test_write_csv_to_file_kwargs(tmpdir, add_kwargs): getattr(context, 'file') -def test_read_csv_from_file_arg0(tmpdir): - fs, filename, services = csv_tester.get_services_for_reader(tmpdir) - - with CapturingNodeExecutionContext( - CsvReader(path=filename, delimiter=',', ioformat=settings.IOFORMAT_ARG0), - services=services, - ) as context: - context.write(BEGIN, Bag(), END) - context.step() - - assert len(context.send.mock_calls) == 2 - - args0, kwargs0 = context.send.call_args_list[0] - assert len(args0) == 1 and not len(kwargs0) - args1, kwargs1 = context.send.call_args_list[1] - assert len(args1) == 1 and not len(kwargs1) - - assert args0[0].args[0] == { - 'a': 'a foo', - 'b': 'b foo', - 'c': 'c foo', - } - assert args1[0].args[0] == { - 'a': 'a bar', - 'b': 'b bar', - 'c': 'c bar', - } - - def test_read_csv_from_file_kwargs(tmpdir): fs, filename, services = csv_tester.get_services_for_reader(tmpdir) - with CapturingNodeExecutionContext( + with BufferingNodeExecutionContext( CsvReader(path=filename, delimiter=','), services=services, ) as context: context.write(BEGIN, Bag(), END) context.step() + output = context.get_buffer() - assert len(context.send.mock_calls) == 2 - - args0, kwargs0 = context.send.call_args_list[0] - assert len(args0) == 1 and not len(kwargs0) - args1, kwargs1 = context.send.call_args_list[1] - assert len(args1) == 1 and not len(kwargs1) - - _args, _kwargs = args0[0].get() - assert not len(_args) and _kwargs == { + assert len(output) == 2 + assert output[0] == { 'a': 'a foo', 'b': 'b foo', 'c': 'c foo', } - - _args, _kwargs = args1[0].get() - assert not len(_args) and _kwargs == { + assert output[1] == { 'a': 'a bar', 'b': 'b bar', 'c': 'c bar', diff --git a/tests/io/test_file.py b/tests/io/test_file.py index 07a15eb..d7645e7 100644 --- a/tests/io/test_file.py +++ b/tests/io/test_file.py @@ -3,7 +3,7 @@ import pytest from bonobo import Bag, FileReader, FileWriter from bonobo.constants import BEGIN, END from bonobo.execution.node import NodeExecutionContext -from bonobo.util.testing import CapturingNodeExecutionContext, FilesystemTester +from bonobo.util.testing import BufferingNodeExecutionContext, FilesystemTester txt_tester = FilesystemTester('txt') txt_tester.input_data = 'Hello\nWorld\n' @@ -41,16 +41,10 @@ def test_file_writer_in_context(tmpdir, lines, output): def test_file_reader(tmpdir): fs, filename, services = txt_tester.get_services_for_reader(tmpdir) - with CapturingNodeExecutionContext(FileReader(path=filename), services=services) as context: - context.write(BEGIN, Bag(), END) - context.step() + with BufferingNodeExecutionContext(FileReader(path=filename), services=services) as context: + context.write_sync(Bag()) + output = context.get_buffer() - assert len(context.send.mock_calls) == 2 - - args0, kwargs0 = context.send.call_args_list[0] - assert len(args0) == 1 and not len(kwargs0) - args1, kwargs1 = context.send.call_args_list[1] - assert len(args1) == 1 and not len(kwargs1) - - assert args0[0].args[0] == 'Hello' - assert args1[0].args[0] == 'World' + assert len(output) == 2 + assert output[0] == 'Hello' + assert output[1] == 'World' diff --git a/tests/io/test_json.py b/tests/io/test_json.py index 75350ce..66c7f94 100644 --- a/tests/io/test_json.py +++ b/tests/io/test_json.py @@ -3,21 +3,20 @@ import pytest from bonobo import Bag, JsonReader, JsonWriter, settings from bonobo.constants import BEGIN, END from bonobo.execution.node import NodeExecutionContext -from bonobo.util.testing import CapturingNodeExecutionContext, FilesystemTester +from bonobo.util.testing import FilesystemTester json_tester = FilesystemTester('json') json_tester.input_data = '''[{"x": "foo"},{"x": "bar"}]''' -def test_write_json_arg0(tmpdir): +def test_write_json_ioformat_arg0(tmpdir): fs, filename, services = json_tester.get_services_for_writer(tmpdir) - with NodeExecutionContext(JsonWriter(filename, ioformat=settings.IOFORMAT_ARG0), services=services) as context: - context.write(BEGIN, Bag({'foo': 'bar'}), END) - context.step() + with pytest.raises(ValueError): + JsonWriter(filename, ioformat=settings.IOFORMAT_ARG0) - with fs.open(filename) as fp: - assert fp.read() == '[{"foo": "bar"}]' + with pytest.raises(ValueError): + JsonReader(filename, ioformat=settings.IOFORMAT_ARG0), @pytest.mark.parametrize('add_kwargs', ({}, { @@ -32,24 +31,3 @@ def test_write_json_kwargs(tmpdir, add_kwargs): with fs.open(filename) as fp: assert fp.read() == '[{"foo": "bar"}]' - - -def test_read_json_arg0(tmpdir): - fs, filename, services = json_tester.get_services_for_reader(tmpdir) - - with CapturingNodeExecutionContext( - JsonReader(filename, ioformat=settings.IOFORMAT_ARG0), - services=services, - ) as context: - context.write(BEGIN, Bag(), END) - context.step() - - assert len(context.send.mock_calls) == 2 - - args0, kwargs0 = context.send.call_args_list[0] - assert len(args0) == 1 and not len(kwargs0) - args1, kwargs1 = context.send.call_args_list[1] - assert len(args1) == 1 and not len(kwargs1) - - assert args0[0].args[0] == {'x': 'foo'} - assert args1[0].args[0] == {'x': 'bar'} diff --git a/tests/io/test_pickle.py b/tests/io/test_pickle.py index aff7796..eca3493 100644 --- a/tests/io/test_pickle.py +++ b/tests/io/test_pickle.py @@ -2,10 +2,9 @@ import pickle import pytest -from bonobo import Bag, PickleReader, PickleWriter, settings -from bonobo.constants import BEGIN, END +from bonobo import Bag, PickleReader, PickleWriter from bonobo.execution.node import NodeExecutionContext -from bonobo.util.testing import CapturingNodeExecutionContext, FilesystemTester +from bonobo.util.testing import BufferingNodeExecutionContext, FilesystemTester pickle_tester = FilesystemTester('pkl', mode='wb') pickle_tester.input_data = pickle.dumps([['a', 'b', 'c'], ['a foo', 'b foo', 'c foo'], ['a bar', 'b bar', 'c bar']]) @@ -14,10 +13,8 @@ pickle_tester.input_data = pickle.dumps([['a', 'b', 'c'], ['a foo', 'b foo', 'c def test_write_pickled_dict_to_file(tmpdir): fs, filename, services = pickle_tester.get_services_for_writer(tmpdir) - with NodeExecutionContext(PickleWriter(filename, ioformat=settings.IOFORMAT_ARG0), services=services) as context: - context.write(BEGIN, Bag({'foo': 'bar'}), Bag({'foo': 'baz', 'ignore': 'this'}), END) - context.step() - context.step() + with NodeExecutionContext(PickleWriter(filename), services=services) as context: + context.write_sync(Bag({'foo': 'bar'}), Bag({'foo': 'baz', 'ignore': 'this'})) with fs.open(filename, 'rb') as fp: assert pickle.loads(fp.read()) == {'foo': 'bar'} @@ -29,25 +26,17 @@ def test_write_pickled_dict_to_file(tmpdir): def test_read_pickled_list_from_file(tmpdir): fs, filename, services = pickle_tester.get_services_for_reader(tmpdir) - with CapturingNodeExecutionContext( - PickleReader(filename, ioformat=settings.IOFORMAT_ARG0), services=services - ) as context: - context.write(BEGIN, Bag(), END) - context.step() + with BufferingNodeExecutionContext(PickleReader(filename), services=services) as context: + context.write_sync(Bag()) + output = context.get_buffer() - assert len(context.send.mock_calls) == 2 - - args0, kwargs0 = context.send.call_args_list[0] - assert len(args0) == 1 and not len(kwargs0) - args1, kwargs1 = context.send.call_args_list[1] - assert len(args1) == 1 and not len(kwargs1) - - assert args0[0].args[0] == { + assert len(output) == 2 + assert output[0] == { 'a': 'a foo', 'b': 'b foo', 'c': 'c foo', } - assert args1[0].args[0] == { + assert output[1] == { 'a': 'a bar', 'b': 'b bar', 'c': 'c bar', diff --git a/tests/test_basicusage.py b/tests/test_basicusage.py index a43831d..58a1212 100644 --- a/tests/test_basicusage.py +++ b/tests/test_basicusage.py @@ -12,4 +12,5 @@ def test_run_graph_noop(): with patch('bonobo._api._is_interactive_console', side_effect=lambda: False): result = bonobo.run(graph) + assert isinstance(result, GraphExecutionContext) diff --git a/tests/test_commands.py b/tests/test_commands.py index a29465c..59cf5f4 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -30,9 +30,13 @@ def test_entrypoint(): for command in pkg_resources.iter_entry_points('bonobo.commands'): commands[command.name] = command - assert 'init' in commands - assert 'run' in commands - assert 'version' in commands + assert not { + 'convert', + 'init', + 'inspect', + 'run', + 'version', + }.difference(set(commands)) @all_runners diff --git a/tests/test_execution.py b/tests/test_execution.py index 70e12ac..6fb33e4 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -51,31 +51,31 @@ def test_simple_execution_context(): graph = Graph() graph.add_chain(*chain) - ctx = GraphExecutionContext(graph) - assert len(ctx.nodes) == len(chain) - assert not len(ctx.plugins) + context = GraphExecutionContext(graph) + assert len(context.nodes) == len(chain) + assert not len(context.plugins) for i, node in enumerate(chain): - assert ctx[i].wrapped is node + assert context[i].wrapped is node - assert not ctx.alive - assert not ctx.started - assert not ctx.stopped + assert not context.alive + assert not context.started + assert not context.stopped - ctx.write(BEGIN, Bag(), END) + context.write(BEGIN, Bag(), END) - assert not ctx.alive - assert not ctx.started - assert not ctx.stopped + assert not context.alive + assert not context.started + assert not context.stopped - ctx.start() + context.start() - assert ctx.alive - assert ctx.started - assert not ctx.stopped + assert context.alive + assert context.started + assert not context.stopped - ctx.stop() + context.stop() - assert not ctx.alive - assert ctx.started - assert ctx.stopped + assert not context.alive + assert context.started + assert context.stopped