Core: refactoring contexts with more logical responsibilities, stopping to rely on kargs ordering for compat with python3.5
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -11,6 +11,7 @@
|
||||
*.so
|
||||
*.spec
|
||||
.*.sw?
|
||||
.DS_Store
|
||||
.Python
|
||||
.cache
|
||||
.coverage
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import types
|
||||
|
||||
from bonobo.util.inspect import istype
|
||||
|
||||
|
||||
@ -143,10 +145,23 @@ class Method(Option):
|
||||
|
||||
>>> example3 = OtherChildMethodExample()
|
||||
|
||||
It's possible to pass a default implementation to a Method by calling it, making it suitable to use as a decorator.
|
||||
|
||||
>>> class MethodExampleWithDefault(Configurable):
|
||||
... @Method()
|
||||
... def handler(self):
|
||||
... pass
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *, required=True, positional=True):
|
||||
super().__init__(None, required=required, positional=positional)
|
||||
def __init__(self, *, required=True, positional=True, __doc__=None):
|
||||
super().__init__(None, required=required, positional=positional, __doc__=__doc__)
|
||||
|
||||
def __get__(self, inst, type_):
|
||||
x = super(Method, self).__get__(inst, type_)
|
||||
if inst:
|
||||
x = types.MethodType(x, inst)
|
||||
return x
|
||||
|
||||
def __set__(self, inst, value):
|
||||
if not hasattr(value, '__call__'):
|
||||
@ -157,6 +172,12 @@ class Method(Option):
|
||||
)
|
||||
inst._options_values[self.name] = self.type(value) if self.type else value
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# only here to trick IDEs into thinking this is callable.
|
||||
raise NotImplementedError('You cannot call the descriptor')
|
||||
def __call__(self, impl):
|
||||
if self.default:
|
||||
raise RuntimeError('Can only be used once as a decorator.')
|
||||
self.default = impl
|
||||
self.required = False
|
||||
return self
|
||||
|
||||
def get_default(self):
|
||||
return self.default
|
||||
|
||||
@ -7,3 +7,7 @@ LOOPBACK = Token('Loopback')
|
||||
NOT_MODIFIED = Token('NotModified')
|
||||
DEFAULT_SERVICES_FILENAME = '_services.py'
|
||||
DEFAULT_SERVICES_ATTR = 'get_services'
|
||||
|
||||
TICK_PERIOD = 0.2
|
||||
|
||||
ARGNAMES = '_argnames'
|
||||
|
||||
@ -1,14 +1,10 @@
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from logging import WARNING, ERROR
|
||||
from logging import ERROR
|
||||
|
||||
import mondrian
|
||||
from bonobo.config import create_container
|
||||
from bonobo.config.processors import ContextCurrifier
|
||||
from bonobo.execution import logger
|
||||
from bonobo.util import isconfigurabletype
|
||||
from bonobo.util.objects import Wrapper, get_name
|
||||
from mondrian import term
|
||||
|
||||
|
||||
@contextmanager
|
||||
@ -28,8 +24,12 @@ def unrecoverable(error_handler):
|
||||
raise # raise unrecoverableerror from x ?
|
||||
|
||||
|
||||
class LoopingExecutionContext(Wrapper):
|
||||
PERIOD = 0.5
|
||||
class Lifecycle:
|
||||
def __init__(self):
|
||||
self._started = False
|
||||
self._stopped = False
|
||||
self._killed = False
|
||||
self._defunct = False
|
||||
|
||||
@property
|
||||
def started(self):
|
||||
@ -39,6 +39,10 @@ class LoopingExecutionContext(Wrapper):
|
||||
def stopped(self):
|
||||
return self._stopped
|
||||
|
||||
@property
|
||||
def killed(self):
|
||||
return self._killed
|
||||
|
||||
@property
|
||||
def defunct(self):
|
||||
return self._defunct
|
||||
@ -47,6 +51,11 @@ class LoopingExecutionContext(Wrapper):
|
||||
def alive(self):
|
||||
return self._started and not self._stopped
|
||||
|
||||
@property
|
||||
def should_loop(self):
|
||||
# TODO XXX started/stopped?
|
||||
return not any((self.defunct, self.killed))
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
"""One character status for this node. """
|
||||
@ -58,23 +67,6 @@ class LoopingExecutionContext(Wrapper):
|
||||
return '+'
|
||||
return '-'
|
||||
|
||||
def __init__(self, wrapped, parent, services=None):
|
||||
super().__init__(wrapped)
|
||||
|
||||
self.parent = parent
|
||||
|
||||
if services:
|
||||
if parent:
|
||||
raise RuntimeError(
|
||||
'Having services defined both in GraphExecutionContext and child NodeExecutionContext is not supported, for now.'
|
||||
)
|
||||
self.services = create_container(services)
|
||||
else:
|
||||
self.services = None
|
||||
|
||||
self._started, self._stopped, self._defunct = False, False, False
|
||||
self._stack = None
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
@ -82,57 +74,54 @@ class LoopingExecutionContext(Wrapper):
|
||||
def __exit__(self, exc_type=None, exc_val=None, exc_tb=None):
|
||||
self.stop()
|
||||
|
||||
def get_flags_as_string(self):
|
||||
if self._defunct:
|
||||
return term.red('[defunct]')
|
||||
if self.killed:
|
||||
return term.lightred('[killed]')
|
||||
if self.stopped:
|
||||
return term.lightblack('[done]')
|
||||
return ''
|
||||
|
||||
def start(self):
|
||||
if self.started:
|
||||
raise RuntimeError('Cannot start a node twice ({}).'.format(get_name(self)))
|
||||
raise RuntimeError('This context is already started ({}).'.format(get_name(self)))
|
||||
|
||||
self._started = True
|
||||
|
||||
try:
|
||||
self._stack = ContextCurrifier(self.wrapped, *self._get_initial_context())
|
||||
if isconfigurabletype(self.wrapped):
|
||||
# Not normal to have a partially configured object here, so let's warn the user instead of having get into
|
||||
# the hard trouble of understanding that by himself.
|
||||
raise TypeError(
|
||||
'The Configurable should be fully instanciated by now, unfortunately I got a PartiallyConfigured object...'
|
||||
)
|
||||
self._stack.setup(self)
|
||||
except Exception:
|
||||
return self.fatal(sys.exc_info())
|
||||
|
||||
def loop(self):
|
||||
"""Generic loop. A bit boring. """
|
||||
while self.alive:
|
||||
self.step()
|
||||
|
||||
def step(self):
|
||||
"""Left as an exercise for the children."""
|
||||
raise NotImplementedError('Abstract.')
|
||||
|
||||
def stop(self):
|
||||
if not self.started:
|
||||
raise RuntimeError('Cannot stop an unstarted node ({}).'.format(get_name(self)))
|
||||
raise RuntimeError('This context cannot be stopped as it never started ({}).'.format(get_name(self)))
|
||||
|
||||
if self._stopped:
|
||||
self._stopped = True
|
||||
|
||||
if self._stopped: # Stopping twice has no effect
|
||||
return
|
||||
|
||||
try:
|
||||
if self._stack:
|
||||
self._stack.teardown()
|
||||
finally:
|
||||
self._stopped = True
|
||||
def kill(self):
|
||||
if not self.started:
|
||||
raise RuntimeError('Cannot kill an unstarted context.')
|
||||
|
||||
def _get_initial_context(self):
|
||||
if self.parent:
|
||||
return self.parent.services.args_for(self.wrapped)
|
||||
if self.services:
|
||||
return self.services.args_for(self.wrapped)
|
||||
return ()
|
||||
if self.stopped:
|
||||
raise RuntimeError('Cannot kill a stopped context.')
|
||||
|
||||
def handle_error(self, exctype, exc, tb, *, level=logging.ERROR):
|
||||
logging.getLogger(__name__).log(level, repr(self), exc_info=(exctype, exc, tb))
|
||||
self._killed = True
|
||||
|
||||
def fatal(self, exc_info):
|
||||
def fatal(self, exc_info, *, level=logging.CRITICAL):
|
||||
logging.getLogger(__name__).log(level, repr(self), exc_info=exc_info)
|
||||
self._defunct = True
|
||||
self.input.shutdown()
|
||||
self.handle_error(*exc_info, level=logging.CRITICAL)
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
'status': self.status,
|
||||
'name': self.name,
|
||||
'stats': self.get_statistics_as_string(),
|
||||
'flags': self.get_flags_as_string(),
|
||||
}
|
||||
|
||||
|
||||
class BaseContext(Lifecycle, Wrapper):
|
||||
def __init__(self, wrapped, *, parent=None):
|
||||
Lifecycle.__init__(self)
|
||||
Wrapper.__init__(self, wrapped)
|
||||
self.parent = parent
|
||||
|
||||
@ -1,38 +1,44 @@
|
||||
import logging
|
||||
import sys
|
||||
import warnings
|
||||
from queue import Empty
|
||||
from time import sleep
|
||||
from types import GeneratorType
|
||||
|
||||
from bonobo.constants import NOT_MODIFIED, BEGIN, END
|
||||
from bonobo.config import create_container
|
||||
from bonobo.config.processors import ContextCurrifier
|
||||
from bonobo.constants import NOT_MODIFIED, BEGIN, END, TICK_PERIOD
|
||||
from bonobo.errors import InactiveReadableError, UnrecoverableError
|
||||
from bonobo.execution.contexts.base import LoopingExecutionContext
|
||||
from bonobo.execution.contexts.base import BaseContext
|
||||
from bonobo.structs.bags import Bag
|
||||
from bonobo.structs.inputs import Input
|
||||
from bonobo.structs.tokens import Token
|
||||
from bonobo.util import get_name, iserrorbag, isloopbackbag, isbag, istuple
|
||||
from bonobo.util.compat import deprecated_alias
|
||||
from bonobo.util import get_name, iserrorbag, isloopbackbag, isbag, istuple, isconfigurabletype
|
||||
from bonobo.util.statistics import WithStatistics
|
||||
from mondrian import term
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeExecutionContext(WithStatistics, LoopingExecutionContext):
|
||||
"""
|
||||
todo: make the counter dependant of parent context?
|
||||
"""
|
||||
|
||||
@property
|
||||
def killed(self):
|
||||
return self._killed
|
||||
|
||||
def __init__(self, wrapped, parent=None, services=None, _input=None, _outputs=None):
|
||||
LoopingExecutionContext.__init__(self, wrapped, parent=parent, services=services)
|
||||
class NodeExecutionContext(BaseContext, WithStatistics):
|
||||
def __init__(self, wrapped, *, parent=None, services=None, _input=None, _outputs=None):
|
||||
BaseContext.__init__(self, wrapped, parent=parent)
|
||||
WithStatistics.__init__(self, 'in', 'out', 'err', 'warn')
|
||||
|
||||
# Services: how we'll access external dependencies
|
||||
if services:
|
||||
if self.parent:
|
||||
raise RuntimeError(
|
||||
'Having services defined both in GraphExecutionContext and child NodeExecutionContext is not supported, for now.'
|
||||
)
|
||||
self.services = create_container(services)
|
||||
else:
|
||||
self.services = None
|
||||
|
||||
# Input / Output: how the wrapped node will communicate
|
||||
self.input = _input or Input()
|
||||
self.outputs = _outputs or []
|
||||
self._killed = False
|
||||
|
||||
# Stack: context decorators for the execution
|
||||
self._stack = None
|
||||
|
||||
def __str__(self):
|
||||
return self.__name__ + self.get_statistics_as_string(prefix=' ')
|
||||
@ -41,14 +47,94 @@ class NodeExecutionContext(WithStatistics, LoopingExecutionContext):
|
||||
name, type_name = get_name(self), get_name(type(self))
|
||||
return '<{}({}{}){}>'.format(type_name, self.status, name, self.get_statistics_as_string(prefix=' '))
|
||||
|
||||
def get_flags_as_string(self):
|
||||
if self._defunct:
|
||||
return term.red('[defunct]')
|
||||
if self.killed:
|
||||
return term.lightred('[killed]')
|
||||
if self.stopped:
|
||||
return term.lightblack('[done]')
|
||||
return ''
|
||||
def start(self):
|
||||
super().start()
|
||||
|
||||
try:
|
||||
self._stack = ContextCurrifier(self.wrapped, *self._get_initial_context())
|
||||
if isconfigurabletype(self.wrapped):
|
||||
# Not normal to have a partially configured object here, so let's warn the user instead of having get into
|
||||
# the hard trouble of understanding that by himself.
|
||||
raise TypeError(
|
||||
'The Configurable should be fully instanciated by now, unfortunately I got a PartiallyConfigured object...'
|
||||
)
|
||||
self._stack.setup(self)
|
||||
except Exception:
|
||||
return self.fatal(sys.exc_info())
|
||||
|
||||
def loop(self):
|
||||
logger.debug('Node loop starts for {!r}.'.format(self))
|
||||
while self.should_loop:
|
||||
try:
|
||||
self.step()
|
||||
except InactiveReadableError:
|
||||
break
|
||||
except Empty:
|
||||
sleep(TICK_PERIOD) # XXX: How do we determine this constant?
|
||||
continue
|
||||
except UnrecoverableError:
|
||||
self.handle_error(*sys.exc_info())
|
||||
self.input.shutdown()
|
||||
break
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self.handle_error(*sys.exc_info())
|
||||
except BaseException:
|
||||
self.handle_error(*sys.exc_info())
|
||||
break
|
||||
logger.debug('Node loop ends for {!r}.'.format(self))
|
||||
|
||||
def step(self):
|
||||
"""Runs a transformation callable with given args/kwargs and flush the result into the right
|
||||
output channel."""
|
||||
|
||||
# Pull data
|
||||
input_bag = self.get()
|
||||
|
||||
# Sent through the stack
|
||||
try:
|
||||
results = input_bag.apply(self._stack)
|
||||
except Exception:
|
||||
return self.handle_error(*sys.exc_info())
|
||||
|
||||
# self._exec_time += timer.duration
|
||||
# Put data onto output channels
|
||||
|
||||
if isinstance(results, GeneratorType):
|
||||
while True:
|
||||
try:
|
||||
# if kill flag was step, stop iterating.
|
||||
if self._killed:
|
||||
break
|
||||
result = next(results)
|
||||
except StopIteration:
|
||||
# That's not an error, we're just done.
|
||||
break
|
||||
except Exception:
|
||||
# Let's kill this loop, won't be able to generate next.
|
||||
self.handle_error(*sys.exc_info())
|
||||
break
|
||||
else:
|
||||
self.send(_resolve(input_bag, result))
|
||||
elif results:
|
||||
self.send(_resolve(input_bag, results))
|
||||
else:
|
||||
# case with no result, an execution went through anyway, use for stats.
|
||||
# self._exec_count += 1
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
if self._stack:
|
||||
self._stack.teardown()
|
||||
|
||||
super().stop()
|
||||
|
||||
def handle_error(self, exctype, exc, tb, *, level=logging.ERROR):
|
||||
self.increment('err')
|
||||
logging.getLogger(__name__).log(level, repr(self), exc_info=(exctype, exc, tb))
|
||||
|
||||
def fatal(self, exc_info, *, level=logging.CRITICAL):
|
||||
super().fatal(exc_info, level=level)
|
||||
self.input.shutdown()
|
||||
|
||||
def write(self, *messages):
|
||||
"""
|
||||
@ -64,9 +150,6 @@ class NodeExecutionContext(WithStatistics, LoopingExecutionContext):
|
||||
for _ in messages:
|
||||
self.step()
|
||||
|
||||
# XXX deprecated alias
|
||||
recv = deprecated_alias('recv', write)
|
||||
|
||||
def send(self, value, _control=False):
|
||||
"""
|
||||
Sends a message to all of this context's outputs.
|
||||
@ -86,89 +169,25 @@ class NodeExecutionContext(WithStatistics, LoopingExecutionContext):
|
||||
for output in self.outputs:
|
||||
output.put(value)
|
||||
|
||||
push = deprecated_alias('push', send)
|
||||
|
||||
def get(self): # recv() ? input_data = self.receive()
|
||||
def get(self):
|
||||
"""
|
||||
Get from the queue first, then increment stats, so if Queue raise Timeout or Empty, stat won't be changed.
|
||||
|
||||
"""
|
||||
row = self.input.get(timeout=self.PERIOD)
|
||||
row = self.input.get() # XXX TIMEOUT ???
|
||||
self.increment('in')
|
||||
return row
|
||||
|
||||
def should_loop(self):
|
||||
return not any((self.defunct, self.killed))
|
||||
|
||||
def loop(self):
|
||||
while self.should_loop():
|
||||
try:
|
||||
self.step()
|
||||
except InactiveReadableError:
|
||||
break
|
||||
except Empty:
|
||||
sleep(self.PERIOD)
|
||||
continue
|
||||
except UnrecoverableError:
|
||||
self.handle_error(*sys.exc_info())
|
||||
self.input.shutdown()
|
||||
break
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self.handle_error(*sys.exc_info())
|
||||
except BaseException:
|
||||
self.handle_error(*sys.exc_info())
|
||||
break
|
||||
|
||||
def step(self):
|
||||
# Pull data from the first available input channel.
|
||||
"""Runs a transformation callable with given args/kwargs and flush the result into the right
|
||||
output channel."""
|
||||
|
||||
input_bag = self.get()
|
||||
|
||||
results = input_bag.apply(self._stack)
|
||||
|
||||
# self._exec_time += timer.duration
|
||||
# Put data onto output channels
|
||||
|
||||
if isinstance(results, GeneratorType):
|
||||
while True:
|
||||
try:
|
||||
# if kill flag was step, stop iterating.
|
||||
if self._killed:
|
||||
break
|
||||
result = next(results)
|
||||
except StopIteration:
|
||||
break
|
||||
else:
|
||||
self.send(_resolve(input_bag, result))
|
||||
elif results:
|
||||
self.send(_resolve(input_bag, results))
|
||||
else:
|
||||
# case with no result, an execution went through anyway, use for stats.
|
||||
# self._exec_count += 1
|
||||
pass
|
||||
|
||||
def kill(self):
|
||||
if not self.started:
|
||||
raise RuntimeError('Cannot kill a node context that has not started yet.')
|
||||
|
||||
if self.stopped:
|
||||
raise RuntimeError('Cannot kill a node context that has already stopped.')
|
||||
|
||||
self._killed = True
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
'status': self.status,
|
||||
'name': self.name,
|
||||
'stats': self.get_statistics_as_string(),
|
||||
'flags': self.get_flags_as_string(),
|
||||
}
|
||||
def _get_initial_context(self):
|
||||
if self.parent:
|
||||
return self.parent.services.args_for(self.wrapped)
|
||||
if self.services:
|
||||
return self.services.args_for(self.wrapped)
|
||||
return ()
|
||||
|
||||
|
||||
def isflag(param):
|
||||
return isinstance(param, Token) and param in (NOT_MODIFIED,)
|
||||
return isinstance(param, Token) and param in (NOT_MODIFIED, )
|
||||
|
||||
|
||||
def split_tokens(output):
|
||||
@ -180,11 +199,11 @@ def split_tokens(output):
|
||||
"""
|
||||
if isinstance(output, Token):
|
||||
# just a flag
|
||||
return (output,), ()
|
||||
return (output, ), ()
|
||||
|
||||
if not istuple(output):
|
||||
# no flag
|
||||
return (), (output,)
|
||||
return (), (output, )
|
||||
|
||||
i = 0
|
||||
while isflag(output[i]):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from bonobo.execution.contexts.base import LoopingExecutionContext
|
||||
from bonobo.execution.contexts.base import BaseContext
|
||||
|
||||
|
||||
class PluginExecutionContext(LoopingExecutionContext):
|
||||
class PluginExecutionContext(BaseContext):
|
||||
@property
|
||||
def dispatcher(self):
|
||||
return self.parent.dispatcher
|
||||
|
||||
@ -52,15 +52,8 @@ class ExecutorStrategy(Strategy):
|
||||
def starter(node):
|
||||
@functools.wraps(node)
|
||||
def _runner():
|
||||
try:
|
||||
with node:
|
||||
node.loop()
|
||||
except:
|
||||
logging.getLogger(__name__).critical(
|
||||
'Uncaught exception in node execution for {}.'.format(node), exc_info=True
|
||||
)
|
||||
node.shutdown()
|
||||
node.stop()
|
||||
with node:
|
||||
node.loop()
|
||||
|
||||
try:
|
||||
futures.append(executor.submit(_runner))
|
||||
|
||||
@ -4,7 +4,7 @@ import itertools
|
||||
from bonobo import settings
|
||||
from bonobo.config import Configurable, Option
|
||||
from bonobo.config.processors import ContextProcessor
|
||||
from bonobo.constants import NOT_MODIFIED
|
||||
from bonobo.constants import NOT_MODIFIED, ARGNAMES
|
||||
from bonobo.structs.bags import Bag
|
||||
from bonobo.util.objects import ValueHolder
|
||||
from bonobo.util.term import CLEAR_EOL
|
||||
@ -88,18 +88,29 @@ class PrettyPrinter(Configurable):
|
||||
|
||||
def call(self, *args, **kwargs):
|
||||
formater = self._format_quiet if settings.QUIET.get() else self._format_console
|
||||
argnames = kwargs.get(ARGNAMES, None)
|
||||
|
||||
for i, (item, value) in enumerate(itertools.chain(enumerate(args), kwargs.items())):
|
||||
print(formater(i, item, value))
|
||||
for i, (item, value) in enumerate(
|
||||
itertools.chain(enumerate(args), filter(lambda x: not x[0].startswith('_'), kwargs.items()))
|
||||
):
|
||||
print(formater(i, item, value, argnames=argnames))
|
||||
|
||||
def _format_quiet(self, i, item, value):
|
||||
def _format_quiet(self, i, item, value, *, argnames=None):
|
||||
# XXX should we implement argnames here ?
|
||||
return ' '.join(((' ' if i else '-'), str(item), ':', str(value).strip()))
|
||||
|
||||
def _format_console(self, i, item, value):
|
||||
def _format_console(self, i, item, value, *, argnames=None):
|
||||
argnames = argnames or []
|
||||
if not isinstance(item, str):
|
||||
if len(argnames) >= item:
|
||||
item = '{} / {}'.format(item, argnames[item])
|
||||
else:
|
||||
item = str(i)
|
||||
|
||||
return ' '.join(
|
||||
(
|
||||
(' ' if i else '•'), str(item), '=', _shorten(str(value).strip(),
|
||||
self.max_width).replace('\n', '\n' + CLEAR_EOL), CLEAR_EOL
|
||||
(' ' if i else '•'), item, '=', _shorten(str(value).strip(),
|
||||
self.max_width).replace('\n', '\n' + CLEAR_EOL), CLEAR_EOL
|
||||
)
|
||||
)
|
||||
|
||||
@ -172,6 +183,3 @@ class FixedWindow(Configurable):
|
||||
if len(buffer) >= self.length:
|
||||
yield buffer.get()
|
||||
buffer.set([])
|
||||
|
||||
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
""" Readers and writers for common file formats. """
|
||||
|
||||
from .csv import CsvReader, CsvWriter
|
||||
from .file import FileReader, FileWriter
|
||||
from .json import JsonReader, JsonWriter, LdjsonReader, LdjsonWriter
|
||||
from .csv import CsvReader, CsvWriter
|
||||
from .pickle import PickleReader, PickleWriter
|
||||
|
||||
__all__ = [
|
||||
|
||||
@ -1,7 +1,4 @@
|
||||
from fs.errors import ResourceNotFound
|
||||
|
||||
from bonobo.config import Configurable, ContextProcessor, Option, Service
|
||||
from bonobo.errors import UnrecoverableError
|
||||
|
||||
|
||||
class FileHandler(Configurable):
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
import csv
|
||||
import warnings
|
||||
|
||||
from bonobo.config import Option
|
||||
from bonobo.config.options import RemovedOption
|
||||
from bonobo.config.processors import ContextProcessor
|
||||
from bonobo.constants import NOT_MODIFIED
|
||||
from bonobo.config import Option, ContextProcessor
|
||||
from bonobo.config.options import RemovedOption, Method
|
||||
from bonobo.constants import NOT_MODIFIED, ARGNAMES
|
||||
from bonobo.nodes.io.base import FileHandler
|
||||
from bonobo.nodes.io.file import FileReader, FileWriter
|
||||
from bonobo.util.objects import ValueHolder
|
||||
from bonobo.structs.bags import Bag
|
||||
from bonobo.util import ensure_tuple
|
||||
|
||||
|
||||
class CsvHandler(FileHandler):
|
||||
@ -28,7 +28,7 @@ class CsvHandler(FileHandler):
|
||||
"""
|
||||
delimiter = Option(str, default=';')
|
||||
quotechar = Option(str, default='"')
|
||||
headers = Option(tuple, required=False)
|
||||
headers = Option(ensure_tuple, required=False)
|
||||
ioformat = RemovedOption(positional=False, value='kwargs')
|
||||
|
||||
|
||||
@ -44,41 +44,66 @@ class CsvReader(FileReader, CsvHandler):
|
||||
|
||||
skip = Option(int, default=0)
|
||||
|
||||
@ContextProcessor
|
||||
def csv_headers(self, context, fs, file):
|
||||
yield ValueHolder(self.headers)
|
||||
@Method(
|
||||
__doc__='''
|
||||
Builds the CSV reader, a.k.a an object we can iterate, each iteration giving one line of fields, as an
|
||||
iterable.
|
||||
|
||||
def read(self, fs, file, headers):
|
||||
reader = csv.reader(file, delimiter=self.delimiter, quotechar=self.quotechar)
|
||||
Defaults to builtin csv.reader(...), but can be overriden to fit your special needs.
|
||||
'''
|
||||
)
|
||||
def reader_factory(self, file):
|
||||
return csv.reader(file, delimiter=self.delimiter, quotechar=self.quotechar)
|
||||
|
||||
if not headers.get():
|
||||
headers.set(next(reader))
|
||||
_headers = headers.get()
|
||||
|
||||
field_count = len(headers)
|
||||
|
||||
if self.skip and self.skip > 0:
|
||||
for _ in range(0, self.skip):
|
||||
next(reader)
|
||||
|
||||
for lineno, row in enumerate(reader):
|
||||
if len(row) != field_count:
|
||||
warnings.warn('Got %d fields on line #%d, expecting %d.' % (len(row), lineno, field_count,))
|
||||
|
||||
yield dict(zip(_headers, row))
|
||||
def read(self, fs, file):
|
||||
reader = self.reader_factory(file)
|
||||
headers = self.headers or next(reader)
|
||||
for row in reader:
|
||||
yield Bag(*row, **{ARGNAMES: headers})
|
||||
|
||||
|
||||
class CsvWriter(FileWriter, CsvHandler):
|
||||
@ContextProcessor
|
||||
def writer(self, context, fs, file, lineno):
|
||||
writer = csv.writer(file, delimiter=self.delimiter, quotechar=self.quotechar, lineterminator=self.eol)
|
||||
headers = ValueHolder(list(self.headers) if self.headers else None)
|
||||
yield writer, headers
|
||||
def context(self, context, *args):
|
||||
yield context
|
||||
|
||||
@Method(
|
||||
__doc__='''
|
||||
Builds the CSV writer, a.k.a an object we can pass a field collection to be written as one line in the
|
||||
target file.
|
||||
|
||||
Defaults to builtin csv.writer(...).writerow, but can be overriden to fit your special needs.
|
||||
'''
|
||||
)
|
||||
def writer_factory(self, file):
|
||||
return csv.writer(file, delimiter=self.delimiter, quotechar=self.quotechar, lineterminator=self.eol).writerow
|
||||
|
||||
def write(self, fs, file, lineno, context, *args, _argnames=None):
|
||||
try:
|
||||
writer = context.writer
|
||||
except AttributeError:
|
||||
context.writer = self.writer_factory(file)
|
||||
writer = context.writer
|
||||
context.headers = self.headers or _argnames
|
||||
|
||||
if context.headers and not lineno:
|
||||
writer(context.headers)
|
||||
|
||||
def write(self, fs, file, lineno, writer, headers, **row):
|
||||
if not lineno:
|
||||
headers.set(headers.value or row.keys())
|
||||
writer.writerow(headers.get())
|
||||
writer.writerow(row.get(header, '') for header in headers.get())
|
||||
lineno += 1
|
||||
|
||||
if context.headers:
|
||||
try:
|
||||
row = [args[i] for i, header in enumerate(context.headers) if header]
|
||||
except IndexError:
|
||||
warnings.warn(
|
||||
'At line #{}, expected {} fields but only got {}. Padding with empty strings.'.format(
|
||||
lineno, len(context.headers), len(args)
|
||||
)
|
||||
)
|
||||
row = [(args[i] if i < len(args) else '') for i, header in enumerate(context.headers) if header]
|
||||
else:
|
||||
row = args
|
||||
|
||||
writer(row)
|
||||
|
||||
return NOT_MODIFIED
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import pickle
|
||||
|
||||
from bonobo.config import Option
|
||||
from bonobo.config.options import RemovedOption
|
||||
from bonobo.config.processors import ContextProcessor
|
||||
from bonobo.constants import NOT_MODIFIED
|
||||
from bonobo.nodes.io.base import FileHandler
|
||||
|
||||
@ -52,8 +52,9 @@ class Bag:
|
||||
# Otherwise, type will handle that for us.
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, *args, _flags=None, _parent=None, **kwargs):
|
||||
def __init__(self, *args, _flags=None, _parent=None, _argnames=None, **kwargs):
|
||||
self._flags = type(self).default_flags + (_flags or ())
|
||||
self._argnames = _argnames
|
||||
self._parent = _parent
|
||||
|
||||
if len(args) == 1 and len(kwargs) == 0:
|
||||
@ -115,9 +116,13 @@ class Bag:
|
||||
def flags(self):
|
||||
return self._flags
|
||||
|
||||
@property
|
||||
def specials(self):
|
||||
return {k: self.__dict__[k] for k in ('_argnames', ) if k in self.__dict__ and self.__dict__[k]}
|
||||
|
||||
def apply(self, func_or_iter, *args, **kwargs):
|
||||
if callable(func_or_iter):
|
||||
return func_or_iter(*args, *self.args, **kwargs, **self.kwargs)
|
||||
return func_or_iter(*args, *self.args, **kwargs, **self.kwargs, **self.specials)
|
||||
|
||||
if len(args) == 0 and len(kwargs) == 0:
|
||||
try:
|
||||
@ -148,7 +153,7 @@ class Bag:
|
||||
|
||||
@classmethod
|
||||
def inherit(cls, *args, **kwargs):
|
||||
return cls(*args, _flags=(INHERIT_INPUT,), **kwargs)
|
||||
return cls(*args, _flags=(INHERIT_INPUT, ), **kwargs)
|
||||
|
||||
def __eq__(self, other):
|
||||
# XXX there are overlapping cases, but this is very handy for now. Let's think about it later.
|
||||
@ -176,9 +181,12 @@ class Bag:
|
||||
|
||||
return len(self.args) == 1 and not self.kwargs and self.args[0] == other
|
||||
|
||||
def args_as_dict(self):
|
||||
return dict(zip(self._argnames, self.args))
|
||||
|
||||
|
||||
class LoopbackBag(Bag):
|
||||
default_flags = (LOOPBACK,)
|
||||
default_flags = (LOOPBACK, )
|
||||
|
||||
|
||||
class ErrorBag(Bag):
|
||||
|
||||
@ -16,6 +16,8 @@ def ensure_tuple(tuple_or_mixed):
|
||||
:return: tuple
|
||||
|
||||
"""
|
||||
if tuple_or_mixed is None:
|
||||
return ()
|
||||
if isinstance(tuple_or_mixed, tuple):
|
||||
return tuple_or_mixed
|
||||
return (tuple_or_mixed, )
|
||||
|
||||
@ -8,7 +8,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from bonobo import open_fs, Token, __main__, get_examples_path
|
||||
from bonobo import open_fs, Token, __main__, get_examples_path, Bag
|
||||
from bonobo.commands import entrypoint
|
||||
from bonobo.execution.contexts.graph import GraphExecutionContext
|
||||
from bonobo.execution.contexts.node import NodeExecutionContext
|
||||
@ -57,6 +57,9 @@ class BufferingContext:
|
||||
def get_buffer(self):
|
||||
return self.buffer
|
||||
|
||||
def get_buffer_args_as_dicts(self):
|
||||
return list(map(lambda x: x.args_as_dict() if isinstance(x, Bag) else dict(x), self.buffer))
|
||||
|
||||
|
||||
class BufferingNodeExecutionContext(BufferingContext, NodeExecutionContext):
|
||||
def __init__(self, *args, buffer=None, **kwargs):
|
||||
|
||||
@ -58,7 +58,7 @@ def test_define_with_decorator():
|
||||
Concrete = MethodBasedConfigurable(my_handler)
|
||||
|
||||
assert callable(Concrete.handler)
|
||||
assert Concrete.handler == my_handler
|
||||
assert Concrete.handler.__func__ == my_handler
|
||||
|
||||
with inspect_node(Concrete) as ci:
|
||||
assert ci.type == MethodBasedConfigurable
|
||||
|
||||
@ -17,17 +17,21 @@ def test_write_csv_ioformat_arg0(tmpdir):
|
||||
CsvReader(path=filename, delimiter=',', ioformat=settings.IOFORMAT_ARG0),
|
||||
|
||||
|
||||
@pytest.mark.parametrize('add_kwargs', (
|
||||
{},
|
||||
{
|
||||
'ioformat': settings.IOFORMAT_KWARGS,
|
||||
},
|
||||
))
|
||||
def test_write_csv_to_file_kwargs(tmpdir, add_kwargs):
|
||||
def test_write_csv_to_file_no_headers(tmpdir):
|
||||
fs, filename, services = csv_tester.get_services_for_writer(tmpdir)
|
||||
|
||||
with NodeExecutionContext(CsvWriter(filename, **add_kwargs), services=services) as context:
|
||||
context.write_sync({'foo': 'bar'}, {'foo': 'baz', 'ignore': 'this'})
|
||||
with NodeExecutionContext(CsvWriter(filename), services=services) as context:
|
||||
context.write_sync(('bar', ), ('baz', 'boo'))
|
||||
|
||||
with fs.open(filename) as fp:
|
||||
assert fp.read() == 'bar\nbaz;boo\n'
|
||||
|
||||
|
||||
def test_write_csv_to_file_with_headers(tmpdir):
|
||||
fs, filename, services = csv_tester.get_services_for_writer(tmpdir)
|
||||
|
||||
with NodeExecutionContext(CsvWriter(filename, headers='foo'), services=services) as context:
|
||||
context.write_sync(('bar', ), ('baz', 'boo'))
|
||||
|
||||
with fs.open(filename) as fp:
|
||||
assert fp.read() == 'foo\nbar\nbaz\n'
|
||||
@ -45,7 +49,7 @@ def test_read_csv_from_file_kwargs(tmpdir):
|
||||
) as context:
|
||||
context.write_sync(())
|
||||
|
||||
assert context.get_buffer() == [
|
||||
assert context.get_buffer_args_as_dicts() == [
|
||||
{
|
||||
'a': 'a foo',
|
||||
'b': 'b foo',
|
||||
|
||||
Reference in New Issue
Block a user