From c87775f090712b1f754a537863438790f4f597b9 Mon Sep 17 00:00:00 2001 From: Romain Dorgueil Date: Sun, 12 Nov 2017 14:22:29 +0100 Subject: [PATCH] Core: refactoring contexts with more logical responsibilities, stopping to rely on kargs ordering for compat with python3.5 --- .gitignore | 1 + bonobo/config/options.py | 31 +++- bonobo/constants.py | 4 + bonobo/execution/contexts/base.py | 119 ++++++------- bonobo/execution/contexts/node.py | 227 +++++++++++++----------- bonobo/execution/contexts/plugin.py | 4 +- bonobo/execution/strategies/executor.py | 11 +- bonobo/nodes/basics.py | 28 +-- bonobo/nodes/io/__init__.py | 2 +- bonobo/nodes/io/base.py | 3 - bonobo/nodes/io/csv.py | 97 ++++++---- bonobo/nodes/io/pickle.py | 1 - bonobo/structs/bags.py | 16 +- bonobo/util/collections.py | 2 + bonobo/util/testing.py | 5 +- tests/config/test_methods.py | 2 +- tests/nodes/io/test_csv.py | 24 +-- 17 files changed, 325 insertions(+), 252 deletions(-) diff --git a/.gitignore b/.gitignore index db473d4..ae199da 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ *.so *.spec .*.sw? +.DS_Store .Python .cache .coverage diff --git a/bonobo/config/options.py b/bonobo/config/options.py index ec02543..51702c3 100644 --- a/bonobo/config/options.py +++ b/bonobo/config/options.py @@ -1,3 +1,5 @@ +import types + from bonobo.util.inspect import istype @@ -143,10 +145,23 @@ class Method(Option): >>> example3 = OtherChildMethodExample() + It's possible to pass a default implementation to a Method by calling it, making it suitable to use as a decorator. + + >>> class MethodExampleWithDefault(Configurable): + ... @Method() + ... def handler(self): + ... pass + """ - def __init__(self, *, required=True, positional=True): - super().__init__(None, required=required, positional=positional) + def __init__(self, *, required=True, positional=True, __doc__=None): + super().__init__(None, required=required, positional=positional, __doc__=__doc__) + + def __get__(self, inst, type_): + x = super(Method, self).__get__(inst, type_) + if inst: + x = types.MethodType(x, inst) + return x def __set__(self, inst, value): if not hasattr(value, '__call__'): @@ -157,6 +172,12 @@ class Method(Option): ) inst._options_values[self.name] = self.type(value) if self.type else value - def __call__(self, *args, **kwargs): - # only here to trick IDEs into thinking this is callable. - raise NotImplementedError('You cannot call the descriptor') + def __call__(self, impl): + if self.default: + raise RuntimeError('Can only be used once as a decorator.') + self.default = impl + self.required = False + return self + + def get_default(self): + return self.default diff --git a/bonobo/constants.py b/bonobo/constants.py index 8c6eba5..7f20fcd 100644 --- a/bonobo/constants.py +++ b/bonobo/constants.py @@ -7,3 +7,7 @@ LOOPBACK = Token('Loopback') NOT_MODIFIED = Token('NotModified') DEFAULT_SERVICES_FILENAME = '_services.py' DEFAULT_SERVICES_ATTR = 'get_services' + +TICK_PERIOD = 0.2 + +ARGNAMES = '_argnames' diff --git a/bonobo/execution/contexts/base.py b/bonobo/execution/contexts/base.py index 3ca580a..847633b 100644 --- a/bonobo/execution/contexts/base.py +++ b/bonobo/execution/contexts/base.py @@ -1,14 +1,10 @@ import logging import sys from contextlib import contextmanager -from logging import WARNING, ERROR +from logging import ERROR -import mondrian -from bonobo.config import create_container -from bonobo.config.processors import ContextCurrifier -from bonobo.execution import logger -from bonobo.util import isconfigurabletype from bonobo.util.objects import Wrapper, get_name +from mondrian import term @contextmanager @@ -28,8 +24,12 @@ def unrecoverable(error_handler): raise # raise unrecoverableerror from x ? -class LoopingExecutionContext(Wrapper): - PERIOD = 0.5 +class Lifecycle: + def __init__(self): + self._started = False + self._stopped = False + self._killed = False + self._defunct = False @property def started(self): @@ -39,6 +39,10 @@ class LoopingExecutionContext(Wrapper): def stopped(self): return self._stopped + @property + def killed(self): + return self._killed + @property def defunct(self): return self._defunct @@ -47,6 +51,11 @@ class LoopingExecutionContext(Wrapper): def alive(self): return self._started and not self._stopped + @property + def should_loop(self): + # TODO XXX started/stopped? + return not any((self.defunct, self.killed)) + @property def status(self): """One character status for this node. """ @@ -58,23 +67,6 @@ class LoopingExecutionContext(Wrapper): return '+' return '-' - def __init__(self, wrapped, parent, services=None): - super().__init__(wrapped) - - self.parent = parent - - if services: - if parent: - raise RuntimeError( - 'Having services defined both in GraphExecutionContext and child NodeExecutionContext is not supported, for now.' - ) - self.services = create_container(services) - else: - self.services = None - - self._started, self._stopped, self._defunct = False, False, False - self._stack = None - def __enter__(self): self.start() return self @@ -82,57 +74,54 @@ class LoopingExecutionContext(Wrapper): def __exit__(self, exc_type=None, exc_val=None, exc_tb=None): self.stop() + def get_flags_as_string(self): + if self._defunct: + return term.red('[defunct]') + if self.killed: + return term.lightred('[killed]') + if self.stopped: + return term.lightblack('[done]') + return '' + def start(self): if self.started: - raise RuntimeError('Cannot start a node twice ({}).'.format(get_name(self))) + raise RuntimeError('This context is already started ({}).'.format(get_name(self))) self._started = True - try: - self._stack = ContextCurrifier(self.wrapped, *self._get_initial_context()) - if isconfigurabletype(self.wrapped): - # Not normal to have a partially configured object here, so let's warn the user instead of having get into - # the hard trouble of understanding that by himself. - raise TypeError( - 'The Configurable should be fully instanciated by now, unfortunately I got a PartiallyConfigured object...' - ) - self._stack.setup(self) - except Exception: - return self.fatal(sys.exc_info()) - - def loop(self): - """Generic loop. A bit boring. """ - while self.alive: - self.step() - - def step(self): - """Left as an exercise for the children.""" - raise NotImplementedError('Abstract.') - def stop(self): if not self.started: - raise RuntimeError('Cannot stop an unstarted node ({}).'.format(get_name(self))) + raise RuntimeError('This context cannot be stopped as it never started ({}).'.format(get_name(self))) - if self._stopped: + self._stopped = True + + if self._stopped: # Stopping twice has no effect return - try: - if self._stack: - self._stack.teardown() - finally: - self._stopped = True + def kill(self): + if not self.started: + raise RuntimeError('Cannot kill an unstarted context.') - def _get_initial_context(self): - if self.parent: - return self.parent.services.args_for(self.wrapped) - if self.services: - return self.services.args_for(self.wrapped) - return () + if self.stopped: + raise RuntimeError('Cannot kill a stopped context.') - def handle_error(self, exctype, exc, tb, *, level=logging.ERROR): - logging.getLogger(__name__).log(level, repr(self), exc_info=(exctype, exc, tb)) + self._killed = True - def fatal(self, exc_info): + def fatal(self, exc_info, *, level=logging.CRITICAL): + logging.getLogger(__name__).log(level, repr(self), exc_info=exc_info) self._defunct = True - self.input.shutdown() - self.handle_error(*exc_info, level=logging.CRITICAL) + + def as_dict(self): + return { + 'status': self.status, + 'name': self.name, + 'stats': self.get_statistics_as_string(), + 'flags': self.get_flags_as_string(), + } + + +class BaseContext(Lifecycle, Wrapper): + def __init__(self, wrapped, *, parent=None): + Lifecycle.__init__(self) + Wrapper.__init__(self, wrapped) + self.parent = parent diff --git a/bonobo/execution/contexts/node.py b/bonobo/execution/contexts/node.py index a18f4d7..3cb3521 100644 --- a/bonobo/execution/contexts/node.py +++ b/bonobo/execution/contexts/node.py @@ -1,38 +1,44 @@ import logging import sys -import warnings from queue import Empty from time import sleep from types import GeneratorType -from bonobo.constants import NOT_MODIFIED, BEGIN, END +from bonobo.config import create_container +from bonobo.config.processors import ContextCurrifier +from bonobo.constants import NOT_MODIFIED, BEGIN, END, TICK_PERIOD from bonobo.errors import InactiveReadableError, UnrecoverableError -from bonobo.execution.contexts.base import LoopingExecutionContext +from bonobo.execution.contexts.base import BaseContext from bonobo.structs.bags import Bag from bonobo.structs.inputs import Input from bonobo.structs.tokens import Token -from bonobo.util import get_name, iserrorbag, isloopbackbag, isbag, istuple -from bonobo.util.compat import deprecated_alias +from bonobo.util import get_name, iserrorbag, isloopbackbag, isbag, istuple, isconfigurabletype from bonobo.util.statistics import WithStatistics -from mondrian import term + +logger = logging.getLogger(__name__) -class NodeExecutionContext(WithStatistics, LoopingExecutionContext): - """ - todo: make the counter dependant of parent context? - """ - - @property - def killed(self): - return self._killed - - def __init__(self, wrapped, parent=None, services=None, _input=None, _outputs=None): - LoopingExecutionContext.__init__(self, wrapped, parent=parent, services=services) +class NodeExecutionContext(BaseContext, WithStatistics): + def __init__(self, wrapped, *, parent=None, services=None, _input=None, _outputs=None): + BaseContext.__init__(self, wrapped, parent=parent) WithStatistics.__init__(self, 'in', 'out', 'err', 'warn') + # Services: how we'll access external dependencies + if services: + if self.parent: + raise RuntimeError( + 'Having services defined both in GraphExecutionContext and child NodeExecutionContext is not supported, for now.' + ) + self.services = create_container(services) + else: + self.services = None + + # Input / Output: how the wrapped node will communicate self.input = _input or Input() self.outputs = _outputs or [] - self._killed = False + + # Stack: context decorators for the execution + self._stack = None def __str__(self): return self.__name__ + self.get_statistics_as_string(prefix=' ') @@ -41,14 +47,94 @@ class NodeExecutionContext(WithStatistics, LoopingExecutionContext): name, type_name = get_name(self), get_name(type(self)) return '<{}({}{}){}>'.format(type_name, self.status, name, self.get_statistics_as_string(prefix=' ')) - def get_flags_as_string(self): - if self._defunct: - return term.red('[defunct]') - if self.killed: - return term.lightred('[killed]') - if self.stopped: - return term.lightblack('[done]') - return '' + def start(self): + super().start() + + try: + self._stack = ContextCurrifier(self.wrapped, *self._get_initial_context()) + if isconfigurabletype(self.wrapped): + # Not normal to have a partially configured object here, so let's warn the user instead of having get into + # the hard trouble of understanding that by himself. + raise TypeError( + 'The Configurable should be fully instanciated by now, unfortunately I got a PartiallyConfigured object...' + ) + self._stack.setup(self) + except Exception: + return self.fatal(sys.exc_info()) + + def loop(self): + logger.debug('Node loop starts for {!r}.'.format(self)) + while self.should_loop: + try: + self.step() + except InactiveReadableError: + break + except Empty: + sleep(TICK_PERIOD) # XXX: How do we determine this constant? + continue + except UnrecoverableError: + self.handle_error(*sys.exc_info()) + self.input.shutdown() + break + except Exception: # pylint: disable=broad-except + self.handle_error(*sys.exc_info()) + except BaseException: + self.handle_error(*sys.exc_info()) + break + logger.debug('Node loop ends for {!r}.'.format(self)) + + def step(self): + """Runs a transformation callable with given args/kwargs and flush the result into the right + output channel.""" + + # Pull data + input_bag = self.get() + + # Sent through the stack + try: + results = input_bag.apply(self._stack) + except Exception: + return self.handle_error(*sys.exc_info()) + + # self._exec_time += timer.duration + # Put data onto output channels + + if isinstance(results, GeneratorType): + while True: + try: + # if kill flag was step, stop iterating. + if self._killed: + break + result = next(results) + except StopIteration: + # That's not an error, we're just done. + break + except Exception: + # Let's kill this loop, won't be able to generate next. + self.handle_error(*sys.exc_info()) + break + else: + self.send(_resolve(input_bag, result)) + elif results: + self.send(_resolve(input_bag, results)) + else: + # case with no result, an execution went through anyway, use for stats. + # self._exec_count += 1 + pass + + def stop(self): + if self._stack: + self._stack.teardown() + + super().stop() + + def handle_error(self, exctype, exc, tb, *, level=logging.ERROR): + self.increment('err') + logging.getLogger(__name__).log(level, repr(self), exc_info=(exctype, exc, tb)) + + def fatal(self, exc_info, *, level=logging.CRITICAL): + super().fatal(exc_info, level=level) + self.input.shutdown() def write(self, *messages): """ @@ -64,9 +150,6 @@ class NodeExecutionContext(WithStatistics, LoopingExecutionContext): for _ in messages: self.step() - # XXX deprecated alias - recv = deprecated_alias('recv', write) - def send(self, value, _control=False): """ Sends a message to all of this context's outputs. @@ -86,89 +169,25 @@ class NodeExecutionContext(WithStatistics, LoopingExecutionContext): for output in self.outputs: output.put(value) - push = deprecated_alias('push', send) - - def get(self): # recv() ? input_data = self.receive() + def get(self): """ Get from the queue first, then increment stats, so if Queue raise Timeout or Empty, stat won't be changed. """ - row = self.input.get(timeout=self.PERIOD) + row = self.input.get() # XXX TIMEOUT ??? self.increment('in') return row - def should_loop(self): - return not any((self.defunct, self.killed)) - - def loop(self): - while self.should_loop(): - try: - self.step() - except InactiveReadableError: - break - except Empty: - sleep(self.PERIOD) - continue - except UnrecoverableError: - self.handle_error(*sys.exc_info()) - self.input.shutdown() - break - except Exception: # pylint: disable=broad-except - self.handle_error(*sys.exc_info()) - except BaseException: - self.handle_error(*sys.exc_info()) - break - - def step(self): - # Pull data from the first available input channel. - """Runs a transformation callable with given args/kwargs and flush the result into the right - output channel.""" - - input_bag = self.get() - - results = input_bag.apply(self._stack) - - # self._exec_time += timer.duration - # Put data onto output channels - - if isinstance(results, GeneratorType): - while True: - try: - # if kill flag was step, stop iterating. - if self._killed: - break - result = next(results) - except StopIteration: - break - else: - self.send(_resolve(input_bag, result)) - elif results: - self.send(_resolve(input_bag, results)) - else: - # case with no result, an execution went through anyway, use for stats. - # self._exec_count += 1 - pass - - def kill(self): - if not self.started: - raise RuntimeError('Cannot kill a node context that has not started yet.') - - if self.stopped: - raise RuntimeError('Cannot kill a node context that has already stopped.') - - self._killed = True - - def as_dict(self): - return { - 'status': self.status, - 'name': self.name, - 'stats': self.get_statistics_as_string(), - 'flags': self.get_flags_as_string(), - } + def _get_initial_context(self): + if self.parent: + return self.parent.services.args_for(self.wrapped) + if self.services: + return self.services.args_for(self.wrapped) + return () def isflag(param): - return isinstance(param, Token) and param in (NOT_MODIFIED,) + return isinstance(param, Token) and param in (NOT_MODIFIED, ) def split_tokens(output): @@ -180,11 +199,11 @@ def split_tokens(output): """ if isinstance(output, Token): # just a flag - return (output,), () + return (output, ), () if not istuple(output): # no flag - return (), (output,) + return (), (output, ) i = 0 while isflag(output[i]): diff --git a/bonobo/execution/contexts/plugin.py b/bonobo/execution/contexts/plugin.py index 524c2e1..3551d0d 100644 --- a/bonobo/execution/contexts/plugin.py +++ b/bonobo/execution/contexts/plugin.py @@ -1,7 +1,7 @@ -from bonobo.execution.contexts.base import LoopingExecutionContext +from bonobo.execution.contexts.base import BaseContext -class PluginExecutionContext(LoopingExecutionContext): +class PluginExecutionContext(BaseContext): @property def dispatcher(self): return self.parent.dispatcher diff --git a/bonobo/execution/strategies/executor.py b/bonobo/execution/strategies/executor.py index f99c4cc..d7a0017 100644 --- a/bonobo/execution/strategies/executor.py +++ b/bonobo/execution/strategies/executor.py @@ -52,15 +52,8 @@ class ExecutorStrategy(Strategy): def starter(node): @functools.wraps(node) def _runner(): - try: - with node: - node.loop() - except: - logging.getLogger(__name__).critical( - 'Uncaught exception in node execution for {}.'.format(node), exc_info=True - ) - node.shutdown() - node.stop() + with node: + node.loop() try: futures.append(executor.submit(_runner)) diff --git a/bonobo/nodes/basics.py b/bonobo/nodes/basics.py index 4054d3d..3a53d2d 100644 --- a/bonobo/nodes/basics.py +++ b/bonobo/nodes/basics.py @@ -4,7 +4,7 @@ import itertools from bonobo import settings from bonobo.config import Configurable, Option from bonobo.config.processors import ContextProcessor -from bonobo.constants import NOT_MODIFIED +from bonobo.constants import NOT_MODIFIED, ARGNAMES from bonobo.structs.bags import Bag from bonobo.util.objects import ValueHolder from bonobo.util.term import CLEAR_EOL @@ -88,18 +88,29 @@ class PrettyPrinter(Configurable): def call(self, *args, **kwargs): formater = self._format_quiet if settings.QUIET.get() else self._format_console + argnames = kwargs.get(ARGNAMES, None) - for i, (item, value) in enumerate(itertools.chain(enumerate(args), kwargs.items())): - print(formater(i, item, value)) + for i, (item, value) in enumerate( + itertools.chain(enumerate(args), filter(lambda x: not x[0].startswith('_'), kwargs.items())) + ): + print(formater(i, item, value, argnames=argnames)) - def _format_quiet(self, i, item, value): + def _format_quiet(self, i, item, value, *, argnames=None): + # XXX should we implement argnames here ? return ' '.join(((' ' if i else '-'), str(item), ':', str(value).strip())) - def _format_console(self, i, item, value): + def _format_console(self, i, item, value, *, argnames=None): + argnames = argnames or [] + if not isinstance(item, str): + if len(argnames) >= item: + item = '{} / {}'.format(item, argnames[item]) + else: + item = str(i) + return ' '.join( ( - (' ' if i else '•'), str(item), '=', _shorten(str(value).strip(), - self.max_width).replace('\n', '\n' + CLEAR_EOL), CLEAR_EOL + (' ' if i else '•'), item, '=', _shorten(str(value).strip(), + self.max_width).replace('\n', '\n' + CLEAR_EOL), CLEAR_EOL ) ) @@ -172,6 +183,3 @@ class FixedWindow(Configurable): if len(buffer) >= self.length: yield buffer.get() buffer.set([]) - - - diff --git a/bonobo/nodes/io/__init__.py b/bonobo/nodes/io/__init__.py index 4e7fbe6..1369ed2 100644 --- a/bonobo/nodes/io/__init__.py +++ b/bonobo/nodes/io/__init__.py @@ -1,8 +1,8 @@ """ Readers and writers for common file formats. """ +from .csv import CsvReader, CsvWriter from .file import FileReader, FileWriter from .json import JsonReader, JsonWriter, LdjsonReader, LdjsonWriter -from .csv import CsvReader, CsvWriter from .pickle import PickleReader, PickleWriter __all__ = [ diff --git a/bonobo/nodes/io/base.py b/bonobo/nodes/io/base.py index db0bc80..af9e609 100644 --- a/bonobo/nodes/io/base.py +++ b/bonobo/nodes/io/base.py @@ -1,7 +1,4 @@ -from fs.errors import ResourceNotFound - from bonobo.config import Configurable, ContextProcessor, Option, Service -from bonobo.errors import UnrecoverableError class FileHandler(Configurable): diff --git a/bonobo/nodes/io/csv.py b/bonobo/nodes/io/csv.py index c846a5f..188fd80 100644 --- a/bonobo/nodes/io/csv.py +++ b/bonobo/nodes/io/csv.py @@ -1,13 +1,13 @@ import csv import warnings -from bonobo.config import Option -from bonobo.config.options import RemovedOption -from bonobo.config.processors import ContextProcessor -from bonobo.constants import NOT_MODIFIED +from bonobo.config import Option, ContextProcessor +from bonobo.config.options import RemovedOption, Method +from bonobo.constants import NOT_MODIFIED, ARGNAMES from bonobo.nodes.io.base import FileHandler from bonobo.nodes.io.file import FileReader, FileWriter -from bonobo.util.objects import ValueHolder +from bonobo.structs.bags import Bag +from bonobo.util import ensure_tuple class CsvHandler(FileHandler): @@ -28,7 +28,7 @@ class CsvHandler(FileHandler): """ delimiter = Option(str, default=';') quotechar = Option(str, default='"') - headers = Option(tuple, required=False) + headers = Option(ensure_tuple, required=False) ioformat = RemovedOption(positional=False, value='kwargs') @@ -44,41 +44,66 @@ class CsvReader(FileReader, CsvHandler): skip = Option(int, default=0) - @ContextProcessor - def csv_headers(self, context, fs, file): - yield ValueHolder(self.headers) + @Method( + __doc__=''' + Builds the CSV reader, a.k.a an object we can iterate, each iteration giving one line of fields, as an + iterable. + + Defaults to builtin csv.reader(...), but can be overriden to fit your special needs. + ''' + ) + def reader_factory(self, file): + return csv.reader(file, delimiter=self.delimiter, quotechar=self.quotechar) - def read(self, fs, file, headers): - reader = csv.reader(file, delimiter=self.delimiter, quotechar=self.quotechar) - - if not headers.get(): - headers.set(next(reader)) - _headers = headers.get() - - field_count = len(headers) - - if self.skip and self.skip > 0: - for _ in range(0, self.skip): - next(reader) - - for lineno, row in enumerate(reader): - if len(row) != field_count: - warnings.warn('Got %d fields on line #%d, expecting %d.' % (len(row), lineno, field_count,)) - - yield dict(zip(_headers, row)) + def read(self, fs, file): + reader = self.reader_factory(file) + headers = self.headers or next(reader) + for row in reader: + yield Bag(*row, **{ARGNAMES: headers}) class CsvWriter(FileWriter, CsvHandler): @ContextProcessor - def writer(self, context, fs, file, lineno): - writer = csv.writer(file, delimiter=self.delimiter, quotechar=self.quotechar, lineterminator=self.eol) - headers = ValueHolder(list(self.headers) if self.headers else None) - yield writer, headers + def context(self, context, *args): + yield context + + @Method( + __doc__=''' + Builds the CSV writer, a.k.a an object we can pass a field collection to be written as one line in the + target file. + + Defaults to builtin csv.writer(...).writerow, but can be overriden to fit your special needs. + ''' + ) + def writer_factory(self, file): + return csv.writer(file, delimiter=self.delimiter, quotechar=self.quotechar, lineterminator=self.eol).writerow + + def write(self, fs, file, lineno, context, *args, _argnames=None): + try: + writer = context.writer + except AttributeError: + context.writer = self.writer_factory(file) + writer = context.writer + context.headers = self.headers or _argnames + + if context.headers and not lineno: + writer(context.headers) - def write(self, fs, file, lineno, writer, headers, **row): - if not lineno: - headers.set(headers.value or row.keys()) - writer.writerow(headers.get()) - writer.writerow(row.get(header, '') for header in headers.get()) lineno += 1 + + if context.headers: + try: + row = [args[i] for i, header in enumerate(context.headers) if header] + except IndexError: + warnings.warn( + 'At line #{}, expected {} fields but only got {}. Padding with empty strings.'.format( + lineno, len(context.headers), len(args) + ) + ) + row = [(args[i] if i < len(args) else '') for i, header in enumerate(context.headers) if header] + else: + row = args + + writer(row) + return NOT_MODIFIED diff --git a/bonobo/nodes/io/pickle.py b/bonobo/nodes/io/pickle.py index 3bb95d6..bc02ce8 100644 --- a/bonobo/nodes/io/pickle.py +++ b/bonobo/nodes/io/pickle.py @@ -1,7 +1,6 @@ import pickle from bonobo.config import Option -from bonobo.config.options import RemovedOption from bonobo.config.processors import ContextProcessor from bonobo.constants import NOT_MODIFIED from bonobo.nodes.io.base import FileHandler diff --git a/bonobo/structs/bags.py b/bonobo/structs/bags.py index f303f92..0738790 100644 --- a/bonobo/structs/bags.py +++ b/bonobo/structs/bags.py @@ -52,8 +52,9 @@ class Bag: # Otherwise, type will handle that for us. return super().__new__(cls) - def __init__(self, *args, _flags=None, _parent=None, **kwargs): + def __init__(self, *args, _flags=None, _parent=None, _argnames=None, **kwargs): self._flags = type(self).default_flags + (_flags or ()) + self._argnames = _argnames self._parent = _parent if len(args) == 1 and len(kwargs) == 0: @@ -115,9 +116,13 @@ class Bag: def flags(self): return self._flags + @property + def specials(self): + return {k: self.__dict__[k] for k in ('_argnames', ) if k in self.__dict__ and self.__dict__[k]} + def apply(self, func_or_iter, *args, **kwargs): if callable(func_or_iter): - return func_or_iter(*args, *self.args, **kwargs, **self.kwargs) + return func_or_iter(*args, *self.args, **kwargs, **self.kwargs, **self.specials) if len(args) == 0 and len(kwargs) == 0: try: @@ -148,7 +153,7 @@ class Bag: @classmethod def inherit(cls, *args, **kwargs): - return cls(*args, _flags=(INHERIT_INPUT,), **kwargs) + return cls(*args, _flags=(INHERIT_INPUT, ), **kwargs) def __eq__(self, other): # XXX there are overlapping cases, but this is very handy for now. Let's think about it later. @@ -176,9 +181,12 @@ class Bag: return len(self.args) == 1 and not self.kwargs and self.args[0] == other + def args_as_dict(self): + return dict(zip(self._argnames, self.args)) + class LoopbackBag(Bag): - default_flags = (LOOPBACK,) + default_flags = (LOOPBACK, ) class ErrorBag(Bag): diff --git a/bonobo/util/collections.py b/bonobo/util/collections.py index 31765c4..d5a4624 100644 --- a/bonobo/util/collections.py +++ b/bonobo/util/collections.py @@ -16,6 +16,8 @@ def ensure_tuple(tuple_or_mixed): :return: tuple """ + if tuple_or_mixed is None: + return () if isinstance(tuple_or_mixed, tuple): return tuple_or_mixed return (tuple_or_mixed, ) diff --git a/bonobo/util/testing.py b/bonobo/util/testing.py index 9044715..66af870 100644 --- a/bonobo/util/testing.py +++ b/bonobo/util/testing.py @@ -8,7 +8,7 @@ from unittest.mock import patch import pytest -from bonobo import open_fs, Token, __main__, get_examples_path +from bonobo import open_fs, Token, __main__, get_examples_path, Bag from bonobo.commands import entrypoint from bonobo.execution.contexts.graph import GraphExecutionContext from bonobo.execution.contexts.node import NodeExecutionContext @@ -57,6 +57,9 @@ class BufferingContext: def get_buffer(self): return self.buffer + def get_buffer_args_as_dicts(self): + return list(map(lambda x: x.args_as_dict() if isinstance(x, Bag) else dict(x), self.buffer)) + class BufferingNodeExecutionContext(BufferingContext, NodeExecutionContext): def __init__(self, *args, buffer=None, **kwargs): diff --git a/tests/config/test_methods.py b/tests/config/test_methods.py index b0154fb..0a3b423 100644 --- a/tests/config/test_methods.py +++ b/tests/config/test_methods.py @@ -58,7 +58,7 @@ def test_define_with_decorator(): Concrete = MethodBasedConfigurable(my_handler) assert callable(Concrete.handler) - assert Concrete.handler == my_handler + assert Concrete.handler.__func__ == my_handler with inspect_node(Concrete) as ci: assert ci.type == MethodBasedConfigurable diff --git a/tests/nodes/io/test_csv.py b/tests/nodes/io/test_csv.py index b0b91c5..0d713bd 100644 --- a/tests/nodes/io/test_csv.py +++ b/tests/nodes/io/test_csv.py @@ -17,17 +17,21 @@ def test_write_csv_ioformat_arg0(tmpdir): CsvReader(path=filename, delimiter=',', ioformat=settings.IOFORMAT_ARG0), -@pytest.mark.parametrize('add_kwargs', ( - {}, - { - 'ioformat': settings.IOFORMAT_KWARGS, - }, -)) -def test_write_csv_to_file_kwargs(tmpdir, add_kwargs): +def test_write_csv_to_file_no_headers(tmpdir): fs, filename, services = csv_tester.get_services_for_writer(tmpdir) - with NodeExecutionContext(CsvWriter(filename, **add_kwargs), services=services) as context: - context.write_sync({'foo': 'bar'}, {'foo': 'baz', 'ignore': 'this'}) + with NodeExecutionContext(CsvWriter(filename), services=services) as context: + context.write_sync(('bar', ), ('baz', 'boo')) + + with fs.open(filename) as fp: + assert fp.read() == 'bar\nbaz;boo\n' + + +def test_write_csv_to_file_with_headers(tmpdir): + fs, filename, services = csv_tester.get_services_for_writer(tmpdir) + + with NodeExecutionContext(CsvWriter(filename, headers='foo'), services=services) as context: + context.write_sync(('bar', ), ('baz', 'boo')) with fs.open(filename) as fp: assert fp.read() == 'foo\nbar\nbaz\n' @@ -45,7 +49,7 @@ def test_read_csv_from_file_kwargs(tmpdir): ) as context: context.write_sync(()) - assert context.get_buffer() == [ + assert context.get_buffer_args_as_dicts() == [ { 'a': 'a foo', 'b': 'b foo',