Core: refactor contexts into more logical responsibilities; stop relying on kwargs ordering for compatibility with Python 3.5

Romain Dorgueil
2017-11-12 14:22:29 +01:00
parent 739a64d8f4
commit c87775f090
17 changed files with 325 additions and 252 deletions

.gitignore vendored
View File

@@ -11,6 +11,7 @@
 *.so
 *.spec
 .*.sw?
+.DS_Store
 .Python
 .cache
 .coverage

View File

@@ -1,3 +1,5 @@
+import types
+
 from bonobo.util.inspect import istype
@@ -143,10 +145,23 @@ class Method(Option):
     >>> example3 = OtherChildMethodExample()

+    It's possible to pass a default implementation to a Method by calling it, making it suitable to use as a decorator.
+
+    >>> class MethodExampleWithDefault(Configurable):
+    ...     @Method()
+    ...     def handler(self):
+    ...         pass
+
     """

-    def __init__(self, *, required=True, positional=True):
-        super().__init__(None, required=required, positional=positional)
+    def __init__(self, *, required=True, positional=True, __doc__=None):
+        super().__init__(None, required=required, positional=positional, __doc__=__doc__)
+
+    def __get__(self, inst, type_):
+        x = super(Method, self).__get__(inst, type_)
+        if inst:
+            x = types.MethodType(x, inst)
+        return x

     def __set__(self, inst, value):
         if not hasattr(value, '__call__'):
@@ -157,6 +172,12 @@ class Method(Option):
             )
         inst._options_values[self.name] = self.type(value) if self.type else value

-    def __call__(self, *args, **kwargs):
-        # only here to trick IDEs into thinking this is callable.
-        raise NotImplementedError('You cannot call the descriptor')
+    def __call__(self, impl):
+        if self.default:
+            raise RuntimeError('Can only be used once as a decorator.')
+        self.default = impl
+        self.required = False
+        return self
+
+    def get_default(self):
+        return self.default
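Reviewer note: a minimal sketch of what the new Method behaviour enables, assuming the decorator and binding semantics shown above (the Greeter class and handler names are illustrative, not part of the commit):

    from bonobo.config import Configurable
    from bonobo.config.options import Method

    class Greeter(Configurable):
        # Default implementation, registered through the new decorator form.
        @Method()
        def handler(self, name):
            return 'Hello, {}!'.format(name)

    # The default is used when no callable is passed...
    print(Greeter().handler('world'))   # expected: Hello, world!

    # ...and an override passed at instantiation is bound to the instance by the
    # new __get__, so it receives `self` like a regular method.
    custom = Greeter(lambda self, name: 'Hi, {}.'.format(name))
    print(custom.handler('world'))      # expected: Hi, world.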

View File

@@ -7,3 +7,7 @@ LOOPBACK = Token('Loopback')
 NOT_MODIFIED = Token('NotModified')
 DEFAULT_SERVICES_FILENAME = '_services.py'
 DEFAULT_SERVICES_ATTR = 'get_services'
+
+TICK_PERIOD = 0.2
+
+ARGNAMES = '_argnames'

View File

@@ -1,14 +1,10 @@
 import logging
 import sys
 from contextlib import contextmanager
-from logging import WARNING, ERROR
+from logging import ERROR

-import mondrian
-from bonobo.config import create_container
-from bonobo.config.processors import ContextCurrifier
-from bonobo.execution import logger
-from bonobo.util import isconfigurabletype
 from bonobo.util.objects import Wrapper, get_name
+from mondrian import term

 @contextmanager
@@ -28,8 +24,12 @@ def unrecoverable(error_handler):
         raise  # raise unrecoverableerror from x ?

-class LoopingExecutionContext(Wrapper):
-    PERIOD = 0.5
+class Lifecycle:
+    def __init__(self):
+        self._started = False
+        self._stopped = False
+        self._killed = False
+        self._defunct = False

     @property
     def started(self):
@@ -39,6 +39,10 @@ class LoopingExecutionContext(Wrapper):
     def stopped(self):
         return self._stopped

+    @property
+    def killed(self):
+        return self._killed
+
     @property
     def defunct(self):
         return self._defunct
@@ -47,6 +51,11 @@ class LoopingExecutionContext(Wrapper):
     def alive(self):
         return self._started and not self._stopped

+    @property
+    def should_loop(self):
+        # TODO XXX started/stopped?
+        return not any((self.defunct, self.killed))
+
     @property
     def status(self):
         """One character status for this node. """
@@ -58,23 +67,6 @@ class LoopingExecutionContext(Wrapper):
             return '+'
         return '-'

-    def __init__(self, wrapped, parent, services=None):
-        super().__init__(wrapped)
-        self.parent = parent
-
-        if services:
-            if parent:
-                raise RuntimeError(
-                    'Having services defined both in GraphExecutionContext and child NodeExecutionContext is not supported, for now.'
-                )
-            self.services = create_container(services)
-        else:
-            self.services = None
-
-        self._started, self._stopped, self._defunct = False, False, False
-        self._stack = None
-
     def __enter__(self):
         self.start()
         return self
@@ -82,57 +74,54 @@ class LoopingExecutionContext(Wrapper):
     def __exit__(self, exc_type=None, exc_val=None, exc_tb=None):
         self.stop()

+    def get_flags_as_string(self):
+        if self._defunct:
+            return term.red('[defunct]')
+        if self.killed:
+            return term.lightred('[killed]')
+        if self.stopped:
+            return term.lightblack('[done]')
+        return ''
+
     def start(self):
         if self.started:
-            raise RuntimeError('Cannot start a node twice ({}).'.format(get_name(self)))
+            raise RuntimeError('This context is already started ({}).'.format(get_name(self)))
         self._started = True

-        try:
-            self._stack = ContextCurrifier(self.wrapped, *self._get_initial_context())
-            if isconfigurabletype(self.wrapped):
-                # Not normal to have a partially configured object here, so let's warn the user instead of having get into
-                # the hard trouble of understanding that by himself.
-                raise TypeError(
-                    'The Configurable should be fully instanciated by now, unfortunately I got a PartiallyConfigured object...'
-                )
-            self._stack.setup(self)
-        except Exception:
-            return self.fatal(sys.exc_info())
-
-    def loop(self):
-        """Generic loop. A bit boring. """
-        while self.alive:
-            self.step()
-
-    def step(self):
-        """Left as an exercise for the children."""
-        raise NotImplementedError('Abstract.')
-
     def stop(self):
         if not self.started:
-            raise RuntimeError('Cannot stop an unstarted node ({}).'.format(get_name(self)))
-        if self._stopped:
+            raise RuntimeError('This context cannot be stopped as it never started ({}).'.format(get_name(self)))
+        self._stopped = True
+        if self._stopped:  # Stopping twice has no effect
             return

-        try:
-            if self._stack:
-                self._stack.teardown()
-        finally:
-            self._stopped = True
+    def kill(self):
+        if not self.started:
+            raise RuntimeError('Cannot kill an unstarted context.')

-    def _get_initial_context(self):
-        if self.parent:
-            return self.parent.services.args_for(self.wrapped)
-        if self.services:
-            return self.services.args_for(self.wrapped)
-        return ()
+        if self.stopped:
+            raise RuntimeError('Cannot kill a stopped context.')

-    def handle_error(self, exctype, exc, tb, *, level=logging.ERROR):
-        logging.getLogger(__name__).log(level, repr(self), exc_info=(exctype, exc, tb))
+        self._killed = True

-    def fatal(self, exc_info):
+    def fatal(self, exc_info, *, level=logging.CRITICAL):
+        logging.getLogger(__name__).log(level, repr(self), exc_info=exc_info)
         self._defunct = True
-        self.input.shutdown()
-        self.handle_error(*exc_info, level=logging.CRITICAL)
+
+    def as_dict(self):
+        return {
+            'status': self.status,
+            'name': self.name,
+            'stats': self.get_statistics_as_string(),
+            'flags': self.get_flags_as_string(),
+        }
+
+
+class BaseContext(Lifecycle, Wrapper):
+    def __init__(self, wrapped, *, parent=None):
+        Lifecycle.__init__(self)
+        Wrapper.__init__(self, wrapped)
+        self.parent = parent
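Reviewer note: with the lifecycle split out, every context is driven the same way the executor strategy does further down: enter the context (start), loop while should_loop holds, exit (stop), with kill() flipping the flag that should_loop checks. A minimal sketch, using a hypothetical do-nothing subclass:

    from bonobo.execution.contexts.base import BaseContext

    class NoopContext(BaseContext):
        def loop(self):
            # A real context pulls work here; see NodeExecutionContext.loop below.
            while self.should_loop:
                break

    with NoopContext(print) as context:  # __enter__ -> start(), __exit__ -> stop()
        context.loop()

    # After the with-block, the lifecycle flags should reflect a clean run:
    # context.started is True, context.stopped is True, context.killed is False.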

View File

@@ -1,38 +1,44 @@
 import logging
 import sys
+import warnings
 from queue import Empty
 from time import sleep
 from types import GeneratorType

-from bonobo.constants import NOT_MODIFIED, BEGIN, END
+from bonobo.config import create_container
+from bonobo.config.processors import ContextCurrifier
+from bonobo.constants import NOT_MODIFIED, BEGIN, END, TICK_PERIOD
 from bonobo.errors import InactiveReadableError, UnrecoverableError
-from bonobo.execution.contexts.base import LoopingExecutionContext
+from bonobo.execution.contexts.base import BaseContext
 from bonobo.structs.bags import Bag
 from bonobo.structs.inputs import Input
 from bonobo.structs.tokens import Token
-from bonobo.util import get_name, iserrorbag, isloopbackbag, isbag, istuple
-from bonobo.util.compat import deprecated_alias
+from bonobo.util import get_name, iserrorbag, isloopbackbag, isbag, istuple, isconfigurabletype
 from bonobo.util.statistics import WithStatistics
-from mondrian import term

+logger = logging.getLogger(__name__)

-class NodeExecutionContext(WithStatistics, LoopingExecutionContext):
-    """
-    todo: make the counter dependant of parent context?
-    """
-
-    @property
-    def killed(self):
-        return self._killed
-
-    def __init__(self, wrapped, parent=None, services=None, _input=None, _outputs=None):
-        LoopingExecutionContext.__init__(self, wrapped, parent=parent, services=services)
+class NodeExecutionContext(BaseContext, WithStatistics):
+    def __init__(self, wrapped, *, parent=None, services=None, _input=None, _outputs=None):
+        BaseContext.__init__(self, wrapped, parent=parent)
         WithStatistics.__init__(self, 'in', 'out', 'err', 'warn')

+        # Services: how we'll access external dependencies
+        if services:
+            if self.parent:
+                raise RuntimeError(
+                    'Having services defined both in GraphExecutionContext and child NodeExecutionContext is not supported, for now.'
+                )
+            self.services = create_container(services)
+        else:
+            self.services = None
+
+        # Input / Output: how the wrapped node will communicate
         self.input = _input or Input()
         self.outputs = _outputs or []
-        self._killed = False
+
+        # Stack: context decorators for the execution
+        self._stack = None

     def __str__(self):
         return self.__name__ + self.get_statistics_as_string(prefix=' ')
@@ -41,14 +47,94 @@ class NodeExecutionContext(WithStatistics, LoopingExecutionContext):
         name, type_name = get_name(self), get_name(type(self))
         return '<{}({}{}){}>'.format(type_name, self.status, name, self.get_statistics_as_string(prefix=' '))

-    def get_flags_as_string(self):
-        if self._defunct:
-            return term.red('[defunct]')
-        if self.killed:
-            return term.lightred('[killed]')
-        if self.stopped:
-            return term.lightblack('[done]')
-        return ''
+    def start(self):
+        super().start()
+
+        try:
+            self._stack = ContextCurrifier(self.wrapped, *self._get_initial_context())
+            if isconfigurabletype(self.wrapped):
+                # Not normal to have a partially configured object here, so let's warn the user instead of having get into
+                # the hard trouble of understanding that by himself.
+                raise TypeError(
+                    'The Configurable should be fully instanciated by now, unfortunately I got a PartiallyConfigured object...'
+                )
+            self._stack.setup(self)
+        except Exception:
+            return self.fatal(sys.exc_info())
+
+    def loop(self):
+        logger.debug('Node loop starts for {!r}.'.format(self))
+        while self.should_loop:
+            try:
+                self.step()
+            except InactiveReadableError:
+                break
+            except Empty:
+                sleep(TICK_PERIOD)  # XXX: How do we determine this constant?
+                continue
+            except UnrecoverableError:
+                self.handle_error(*sys.exc_info())
+                self.input.shutdown()
+                break
+            except Exception:  # pylint: disable=broad-except
+                self.handle_error(*sys.exc_info())
+            except BaseException:
+                self.handle_error(*sys.exc_info())
+                break
+        logger.debug('Node loop ends for {!r}.'.format(self))
+
+    def step(self):
+        """Runs a transformation callable with given args/kwargs and flush the result into the right
+        output channel."""
+
+        # Pull data
+        input_bag = self.get()
+
+        # Sent through the stack
+        try:
+            results = input_bag.apply(self._stack)
+        except Exception:
+            return self.handle_error(*sys.exc_info())
+
+        # self._exec_time += timer.duration
+
+        # Put data onto output channels
+        if isinstance(results, GeneratorType):
+            while True:
+                try:
+                    # if kill flag was step, stop iterating.
+                    if self._killed:
+                        break
+                    result = next(results)
+                except StopIteration:
+                    # That's not an error, we're just done.
+                    break
+                except Exception:
+                    # Let's kill this loop, won't be able to generate next.
+                    self.handle_error(*sys.exc_info())
+                    break
+                else:
+                    self.send(_resolve(input_bag, result))
+        elif results:
+            self.send(_resolve(input_bag, results))
+        else:
+            # case with no result, an execution went through anyway, use for stats.
+            # self._exec_count += 1
+            pass
+
+    def stop(self):
+        if self._stack:
+            self._stack.teardown()
+        super().stop()
+
+    def handle_error(self, exctype, exc, tb, *, level=logging.ERROR):
+        self.increment('err')
+        logging.getLogger(__name__).log(level, repr(self), exc_info=(exctype, exc, tb))
+
+    def fatal(self, exc_info, *, level=logging.CRITICAL):
+        super().fatal(exc_info, level=level)
+        self.input.shutdown()

     def write(self, *messages):
         """
@@ -64,9 +150,6 @@ class NodeExecutionContext(WithStatistics, LoopingExecutionContext):
         for _ in messages:
             self.step()

-    # XXX deprecated alias
-    recv = deprecated_alias('recv', write)
-
     def send(self, value, _control=False):
         """
         Sends a message to all of this context's outputs.
@@ -86,89 +169,25 @@ class NodeExecutionContext(WithStatistics, LoopingExecutionContext):
         for output in self.outputs:
             output.put(value)

-    push = deprecated_alias('push', send)
-
-    def get(self):  # recv() ?
+    def get(self):
+        input_data = self.receive()
         """
         Get from the queue first, then increment stats, so if Queue raise Timeout or Empty, stat won't be changed.
         """
-        row = self.input.get(timeout=self.PERIOD)
+        row = self.input.get()  # XXX TIMEOUT ???
         self.increment('in')
         return row

-    def should_loop(self):
-        return not any((self.defunct, self.killed))
-
-    def loop(self):
-        while self.should_loop():
-            try:
-                self.step()
-            except InactiveReadableError:
-                break
-            except Empty:
-                sleep(self.PERIOD)
-                continue
-            except UnrecoverableError:
-                self.handle_error(*sys.exc_info())
-                self.input.shutdown()
-                break
-            except Exception:  # pylint: disable=broad-except
-                self.handle_error(*sys.exc_info())
-            except BaseException:
-                self.handle_error(*sys.exc_info())
-                break
-
-    def step(self):
-        # Pull data from the first available input channel.
-        """Runs a transformation callable with given args/kwargs and flush the result into the right
-        output channel."""
-        input_bag = self.get()
-        results = input_bag.apply(self._stack)
-        # self._exec_time += timer.duration
-
-        # Put data onto output channels
-        if isinstance(results, GeneratorType):
-            while True:
-                try:
-                    # if kill flag was step, stop iterating.
-                    if self._killed:
-                        break
-                    result = next(results)
-                except StopIteration:
-                    break
-                else:
-                    self.send(_resolve(input_bag, result))
-        elif results:
-            self.send(_resolve(input_bag, results))
-        else:
-            # case with no result, an execution went through anyway, use for stats.
-            # self._exec_count += 1
-            pass
-
-    def kill(self):
-        if not self.started:
-            raise RuntimeError('Cannot kill a node context that has not started yet.')
-
-        if self.stopped:
-            raise RuntimeError('Cannot kill a node context that has already stopped.')
-
-        self._killed = True
-
-    def as_dict(self):
-        return {
-            'status': self.status,
-            'name': self.name,
-            'stats': self.get_statistics_as_string(),
-            'flags': self.get_flags_as_string(),
-        }
+    def _get_initial_context(self):
+        if self.parent:
+            return self.parent.services.args_for(self.wrapped)
+        if self.services:
+            return self.services.args_for(self.wrapped)
+        return ()


 def isflag(param):
-    return isinstance(param, Token) and param in (NOT_MODIFIED,)
+    return isinstance(param, Token) and param in (NOT_MODIFIED, )


 def split_tokens(output):
@@ -180,11 +199,11 @@ def split_tokens(output):
     """
     if isinstance(output, Token):
         # just a flag
-        return (output,), ()
+        return (output, ), ()
     if not istuple(output):
         # no flag
-        return (), (output,)
+        return (), (output, )
     i = 0
     while isflag(output[i]):
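Reviewer note: the relocated start/loop/step/stop methods are exercised the same way the test helpers later in this commit use them. A minimal sketch, assuming an empty service container is acceptable here, with uppercase() as a purely illustrative transformation:

    from bonobo.execution.contexts.node import NodeExecutionContext

    def uppercase(*values):
        yield tuple(v.upper() for v in values)

    with NodeExecutionContext(uppercase, services={}) as context:
        # write_sync() queues each message and step()s the node once per message;
        # loop() would instead poll the input, sleeping TICK_PERIOD when it is empty.
        context.write_sync(('hello', ), ('world', ))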

View File

@@ -1,7 +1,7 @@
-from bonobo.execution.contexts.base import LoopingExecutionContext
+from bonobo.execution.contexts.base import BaseContext


-class PluginExecutionContext(LoopingExecutionContext):
+class PluginExecutionContext(BaseContext):
     @property
     def dispatcher(self):
         return self.parent.dispatcher

View File

@@ -52,15 +52,8 @@ class ExecutorStrategy(Strategy):
         def starter(node):
             @functools.wraps(node)
             def _runner():
-                try:
-                    with node:
-                        node.loop()
-                except:
-                    logging.getLogger(__name__).critical(
-                        'Uncaught exception in node execution for {}.'.format(node), exc_info=True
-                    )
-                    node.shutdown()
-                    node.stop()
+                with node:
+                    node.loop()

             try:
                 futures.append(executor.submit(_runner))

View File

@@ -4,7 +4,7 @@ import itertools
 from bonobo import settings
 from bonobo.config import Configurable, Option
 from bonobo.config.processors import ContextProcessor
-from bonobo.constants import NOT_MODIFIED
+from bonobo.constants import NOT_MODIFIED, ARGNAMES
 from bonobo.structs.bags import Bag
 from bonobo.util.objects import ValueHolder
 from bonobo.util.term import CLEAR_EOL
@@ -88,18 +88,29 @@ class PrettyPrinter(Configurable):
     def call(self, *args, **kwargs):
         formater = self._format_quiet if settings.QUIET.get() else self._format_console

-        for i, (item, value) in enumerate(itertools.chain(enumerate(args), kwargs.items())):
-            print(formater(i, item, value))
+        argnames = kwargs.get(ARGNAMES, None)
+        for i, (item, value) in enumerate(
+            itertools.chain(enumerate(args), filter(lambda x: not x[0].startswith('_'), kwargs.items()))
+        ):
+            print(formater(i, item, value, argnames=argnames))

-    def _format_quiet(self, i, item, value):
+    def _format_quiet(self, i, item, value, *, argnames=None):
+        # XXX should we implement argnames here ?
         return ' '.join(((' ' if i else '-'), str(item), ':', str(value).strip()))

-    def _format_console(self, i, item, value):
+    def _format_console(self, i, item, value, *, argnames=None):
+        argnames = argnames or []
+        if not isinstance(item, str):
+            if len(argnames) >= item:
+                item = '{} / {}'.format(item, argnames[item])
+            else:
+                item = str(i)
+
         return ' '.join(
             (
-                (' ' if i else ''), str(item), '=', _shorten(str(value).strip(),
+                (' ' if i else ''), item, '=', _shorten(str(value).strip(),
                 self.max_width).replace('\n', '\n' + CLEAR_EOL), CLEAR_EOL
             )
         )
@@ -172,6 +183,3 @@ class FixedWindow(Configurable):
             if len(buffer) >= self.length:
                 yield buffer.get()
                 buffer.set([])

View File

@@ -1,8 +1,8 @@
 """ Readers and writers for common file formats. """

+from .csv import CsvReader, CsvWriter
 from .file import FileReader, FileWriter
 from .json import JsonReader, JsonWriter, LdjsonReader, LdjsonWriter
-from .csv import CsvReader, CsvWriter
 from .pickle import PickleReader, PickleWriter

 __all__ = [

View File

@@ -1,7 +1,4 @@
-from fs.errors import ResourceNotFound
-
 from bonobo.config import Configurable, ContextProcessor, Option, Service
-from bonobo.errors import UnrecoverableError


 class FileHandler(Configurable):

View File

@@ -1,13 +1,13 @@
 import csv
 import warnings

-from bonobo.config import Option
-from bonobo.config.options import RemovedOption
-from bonobo.config.processors import ContextProcessor
-from bonobo.constants import NOT_MODIFIED
+from bonobo.config import Option, ContextProcessor
+from bonobo.config.options import RemovedOption, Method
+from bonobo.constants import NOT_MODIFIED, ARGNAMES
 from bonobo.nodes.io.base import FileHandler
 from bonobo.nodes.io.file import FileReader, FileWriter
-from bonobo.util.objects import ValueHolder
+from bonobo.structs.bags import Bag
+from bonobo.util import ensure_tuple


 class CsvHandler(FileHandler):
@@ -28,7 +28,7 @@ class CsvHandler(FileHandler):
     """
     delimiter = Option(str, default=';')
     quotechar = Option(str, default='"')
-    headers = Option(tuple, required=False)
+    headers = Option(ensure_tuple, required=False)

     ioformat = RemovedOption(positional=False, value='kwargs')

@@ -44,41 +44,66 @@ class CsvReader(FileReader, CsvHandler):
     skip = Option(int, default=0)

-    @ContextProcessor
-    def csv_headers(self, context, fs, file):
-        yield ValueHolder(self.headers)
+    @Method(
+        __doc__='''
+            Builds the CSV reader, a.k.a an object we can iterate, each iteration giving one line of fields, as an
+            iterable.
+
+            Defaults to builtin csv.reader(...), but can be overriden to fit your special needs.
+        '''
+    )
+    def reader_factory(self, file):
+        return csv.reader(file, delimiter=self.delimiter, quotechar=self.quotechar)

-    def read(self, fs, file, headers):
-        reader = csv.reader(file, delimiter=self.delimiter, quotechar=self.quotechar)
-
-        if not headers.get():
-            headers.set(next(reader))
-        _headers = headers.get()
-
-        field_count = len(headers)
-
-        if self.skip and self.skip > 0:
-            for _ in range(0, self.skip):
-                next(reader)
-
-        for lineno, row in enumerate(reader):
-            if len(row) != field_count:
-                warnings.warn('Got %d fields on line #%d, expecting %d.' % (len(row), lineno, field_count,))
-
-            yield dict(zip(_headers, row))
+    def read(self, fs, file):
+        reader = self.reader_factory(file)
+        headers = self.headers or next(reader)
+
+        for row in reader:
+            yield Bag(*row, **{ARGNAMES: headers})


 class CsvWriter(FileWriter, CsvHandler):
     @ContextProcessor
-    def writer(self, context, fs, file, lineno):
-        writer = csv.writer(file, delimiter=self.delimiter, quotechar=self.quotechar, lineterminator=self.eol)
-        headers = ValueHolder(list(self.headers) if self.headers else None)
-        yield writer, headers
-
-    def write(self, fs, file, lineno, writer, headers, **row):
-        if not lineno:
-            headers.set(headers.value or row.keys())
-            writer.writerow(headers.get())
-        writer.writerow(row.get(header, '') for header in headers.get())
+    def context(self, context, *args):
+        yield context
+
+    @Method(
+        __doc__='''
+            Builds the CSV writer, a.k.a an object we can pass a field collection to be written as one line in the
+            target file.
+
+            Defaults to builtin csv.writer(...).writerow, but can be overriden to fit your special needs.
+        '''
+    )
+    def writer_factory(self, file):
+        return csv.writer(file, delimiter=self.delimiter, quotechar=self.quotechar, lineterminator=self.eol).writerow
+
+    def write(self, fs, file, lineno, context, *args, _argnames=None):
+        try:
+            writer = context.writer
+        except AttributeError:
+            context.writer = self.writer_factory(file)
+            writer = context.writer
+
+        context.headers = self.headers or _argnames
+        if context.headers and not lineno:
+            writer(context.headers)
+
         lineno += 1
+        if context.headers:
+            try:
+                row = [args[i] for i, header in enumerate(context.headers) if header]
+            except IndexError:
+                warnings.warn(
+                    'At line #{}, expected {} fields but only got {}. Padding with empty strings.'.format(
+                        lineno, len(context.headers), len(args)
+                    )
+                )
+                row = [(args[i] if i < len(args) else '') for i, header in enumerate(context.headers) if header]
+        else:
+            row = args
+        writer(row)
         return NOT_MODIFIED
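Reviewer note: the Method-based factories give callers a hook the old ContextProcessor-based implementation did not expose. A minimal sketch, assuming a Method option can be overridden by name at instantiation like any other option (the tab-separated dialect is illustrative only, not part of the commit):

    import csv

    from bonobo.nodes.io.csv import CsvReader

    def tsv_reader_factory(self, file):
        # Ignore the configured delimiter and read tab-separated values instead.
        return csv.reader(file, dialect='excel-tab')

    reader = CsvReader('input.tsv', reader_factory=tsv_reader_factory)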

View File

@@ -1,7 +1,6 @@
 import pickle

 from bonobo.config import Option
-from bonobo.config.options import RemovedOption
 from bonobo.config.processors import ContextProcessor
 from bonobo.constants import NOT_MODIFIED
 from bonobo.nodes.io.base import FileHandler

View File

@@ -52,8 +52,9 @@ class Bag:
         # Otherwise, type will handle that for us.
         return super().__new__(cls)

-    def __init__(self, *args, _flags=None, _parent=None, **kwargs):
+    def __init__(self, *args, _flags=None, _parent=None, _argnames=None, **kwargs):
         self._flags = type(self).default_flags + (_flags or ())
+        self._argnames = _argnames
         self._parent = _parent

         if len(args) == 1 and len(kwargs) == 0:
@@ -115,9 +116,13 @@ class Bag:
     def flags(self):
         return self._flags

+    @property
+    def specials(self):
+        return {k: self.__dict__[k] for k in ('_argnames', ) if k in self.__dict__ and self.__dict__[k]}
+
     def apply(self, func_or_iter, *args, **kwargs):
         if callable(func_or_iter):
-            return func_or_iter(*args, *self.args, **kwargs, **self.kwargs)
+            return func_or_iter(*args, *self.args, **kwargs, **self.kwargs, **self.specials)

         if len(args) == 0 and len(kwargs) == 0:
             try:
@@ -148,7 +153,7 @@ class Bag:
     @classmethod
     def inherit(cls, *args, **kwargs):
-        return cls(*args, _flags=(INHERIT_INPUT,), **kwargs)
+        return cls(*args, _flags=(INHERIT_INPUT, ), **kwargs)

     def __eq__(self, other):
         # XXX there are overlapping cases, but this is very handy for now. Let's think about it later.
@@ -176,9 +181,12 @@ class Bag:
         return len(self.args) == 1 and not self.kwargs and self.args[0] == other

+    def args_as_dict(self):
+        return dict(zip(self._argnames, self.args))
+

 class LoopbackBag(Bag):
-    default_flags = (LOOPBACK,)
+    default_flags = (LOOPBACK, )


 class ErrorBag(Bag):
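Reviewer note: this is the piece that actually removes the dependence on kwargs ordering mentioned in the commit message — rows travel as positional args plus an _argnames tuple instead of as an ordered dict. A minimal sketch (field names are illustrative):

    from bonobo.structs.bags import Bag

    bag = Bag('alice', '42', _argnames=('name', 'age'))

    bag.args             # ('alice', '42')
    bag.args_as_dict()   # {'name': 'alice', 'age': '42'}

    # apply() forwards the _argnames special as a keyword, so consumers such as
    # PrettyPrinter or CsvWriter can label positional values themselves.
    bag.apply(lambda *args, _argnames=None: dict(zip(_argnames, args)))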

View File

@@ -16,6 +16,8 @@ def ensure_tuple(tuple_or_mixed):
     :return: tuple
     """
+    if tuple_or_mixed is None:
+        return ()
     if isinstance(tuple_or_mixed, tuple):
         return tuple_or_mixed

     return (tuple_or_mixed, )
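Reviewer note: a quick illustration of the new None handling, and of why headers = Option(ensure_tuple, ...) above benefits from it (values are illustrative):

    from bonobo.util import ensure_tuple

    ensure_tuple(None)        # () -- the case added here
    ensure_tuple('foo')       # ('foo', ), which is how CsvWriter(..., headers='foo') gets a usable tuple
    ensure_tuple(('a', 'b'))  # ('a', 'b'), passed through unchanged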

View File

@@ -8,7 +8,7 @@ from unittest.mock import patch

 import pytest

-from bonobo import open_fs, Token, __main__, get_examples_path
+from bonobo import open_fs, Token, __main__, get_examples_path, Bag
 from bonobo.commands import entrypoint
 from bonobo.execution.contexts.graph import GraphExecutionContext
 from bonobo.execution.contexts.node import NodeExecutionContext
@@ -57,6 +57,9 @@ class BufferingContext:
     def get_buffer(self):
         return self.buffer

+    def get_buffer_args_as_dicts(self):
+        return list(map(lambda x: x.args_as_dict() if isinstance(x, Bag) else dict(x), self.buffer))
+

 class BufferingNodeExecutionContext(BufferingContext, NodeExecutionContext):
     def __init__(self, *args, buffer=None, **kwargs):

View File

@@ -58,7 +58,7 @@ def test_define_with_decorator():
     Concrete = MethodBasedConfigurable(my_handler)
     assert callable(Concrete.handler)
-    assert Concrete.handler == my_handler
+    assert Concrete.handler.__func__ == my_handler

     with inspect_node(Concrete) as ci:
         assert ci.type == MethodBasedConfigurable

View File

@@ -17,17 +17,21 @@ def test_write_csv_ioformat_arg0(tmpdir):
             CsvReader(path=filename, delimiter=',', ioformat=settings.IOFORMAT_ARG0),


-@pytest.mark.parametrize('add_kwargs', (
-    {},
-    {
-        'ioformat': settings.IOFORMAT_KWARGS,
-    },
-))
-def test_write_csv_to_file_kwargs(tmpdir, add_kwargs):
+def test_write_csv_to_file_no_headers(tmpdir):
     fs, filename, services = csv_tester.get_services_for_writer(tmpdir)

-    with NodeExecutionContext(CsvWriter(filename, **add_kwargs), services=services) as context:
-        context.write_sync({'foo': 'bar'}, {'foo': 'baz', 'ignore': 'this'})
+    with NodeExecutionContext(CsvWriter(filename), services=services) as context:
+        context.write_sync(('bar', ), ('baz', 'boo'))
+
+    with fs.open(filename) as fp:
+        assert fp.read() == 'bar\nbaz;boo\n'
+
+
+def test_write_csv_to_file_with_headers(tmpdir):
+    fs, filename, services = csv_tester.get_services_for_writer(tmpdir)
+
+    with NodeExecutionContext(CsvWriter(filename, headers='foo'), services=services) as context:
+        context.write_sync(('bar', ), ('baz', 'boo'))

     with fs.open(filename) as fp:
         assert fp.read() == 'foo\nbar\nbaz\n'
@@ -45,7 +49,7 @@ def test_read_csv_from_file_kwargs(tmpdir):
     ) as context:
         context.write_sync(())

-    assert context.get_buffer() == [
+    assert context.get_buffer_args_as_dicts() == [
         {
             'a': 'a foo',
             'b': 'b foo',