Inheritance of bags and better jupyter output for pretty printer.
This commit is contained in:
@ -11,9 +11,18 @@ class Token:
|
|||||||
BEGIN = Token('Begin')
|
BEGIN = Token('Begin')
|
||||||
END = Token('End')
|
END = Token('End')
|
||||||
|
|
||||||
INHERIT_INPUT = Token('InheritInput')
|
|
||||||
LOOPBACK = Token('Loopback')
|
class Flag(Token):
|
||||||
NOT_MODIFIED = Token('NotModified')
|
must_be_first = False
|
||||||
|
must_be_last = False
|
||||||
|
allows_data = True
|
||||||
|
|
||||||
|
|
||||||
|
INHERIT = Flag('Inherit')
|
||||||
|
NOT_MODIFIED = Flag('NotModified')
|
||||||
|
NOT_MODIFIED.must_be_first = True
|
||||||
|
NOT_MODIFIED.must_be_last = True
|
||||||
|
NOT_MODIFIED.allows_data = False
|
||||||
|
|
||||||
EMPTY = tuple()
|
EMPTY = tuple()
|
||||||
|
|
||||||
|
|||||||
@ -27,6 +27,6 @@ def get_graph_options(options):
|
|||||||
_print = options.pop('print', False)
|
_print = options.pop('print', False)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'_limit': (bonobo.Limit(_limit),) if _limit else (),
|
'_limit': (bonobo.Limit(_limit), ) if _limit else (),
|
||||||
'_print': (bonobo.PrettyPrinter(),) if _print else (),
|
'_print': (bonobo.PrettyPrinter(), ) if _print else (),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -66,11 +66,16 @@ def get_services():
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = examples.get_argument_parser()
|
parser = examples.get_argument_parser()
|
||||||
parser.add_argument('--target', '-t', choices=graphs.keys(), nargs='+')
|
parser.add_argument(
|
||||||
|
'--target', '-t', choices=graphs.keys(), nargs='+'
|
||||||
|
)
|
||||||
|
|
||||||
with bonobo.parse_args(parser) as options:
|
with bonobo.parse_args(parser) as options:
|
||||||
graph_options = examples.get_graph_options(options)
|
graph_options = examples.get_graph_options(options)
|
||||||
graph_names = list(options['target'] if options['target'] else sorted(graphs.keys()))
|
graph_names = list(
|
||||||
|
options['target']
|
||||||
|
if options['target'] else sorted(graphs.keys())
|
||||||
|
)
|
||||||
|
|
||||||
graph = bonobo.Graph()
|
graph = bonobo.Graph()
|
||||||
for name in graph_names:
|
for name in graph_names:
|
||||||
|
|||||||
@ -0,0 +1,9 @@
|
|||||||
|
from bonobo.execution.contexts.graph import GraphExecutionContext
|
||||||
|
from bonobo.execution.contexts.node import NodeExecutionContext
|
||||||
|
from bonobo.execution.contexts.plugin import PluginExecutionContext
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'GraphExecutionContext',
|
||||||
|
'NodeExecutionContext',
|
||||||
|
'PluginExecutionContext',
|
||||||
|
]
|
||||||
|
|||||||
@ -7,11 +7,11 @@ from types import GeneratorType
|
|||||||
|
|
||||||
from bonobo.config import create_container
|
from bonobo.config import create_container
|
||||||
from bonobo.config.processors import ContextCurrifier
|
from bonobo.config.processors import ContextCurrifier
|
||||||
from bonobo.constants import NOT_MODIFIED, BEGIN, END, TICK_PERIOD, Token
|
from bonobo.constants import NOT_MODIFIED, BEGIN, END, TICK_PERIOD, Token, Flag, INHERIT
|
||||||
from bonobo.errors import InactiveReadableError, UnrecoverableError, UnrecoverableTypeError
|
from bonobo.errors import InactiveReadableError, UnrecoverableError, UnrecoverableTypeError
|
||||||
from bonobo.execution.contexts.base import BaseContext
|
from bonobo.execution.contexts.base import BaseContext
|
||||||
from bonobo.structs.inputs import Input
|
from bonobo.structs.inputs import Input
|
||||||
from bonobo.util import get_name, istuple, isconfigurabletype, ensure_tuple
|
from bonobo.util import get_name, isconfigurabletype, ensure_tuple
|
||||||
from bonobo.util.bags import BagType
|
from bonobo.util.bags import BagType
|
||||||
from bonobo.util.statistics import WithStatistics
|
from bonobo.util.statistics import WithStatistics
|
||||||
|
|
||||||
@ -292,20 +292,24 @@ class NodeExecutionContext(BaseContext, WithStatistics):
|
|||||||
|
|
||||||
def _cast(self, _input, _output):
|
def _cast(self, _input, _output):
|
||||||
"""
|
"""
|
||||||
Transforms a pair of input/output into what is the real output.
|
Transforms a pair of input/output into the real slim output.
|
||||||
|
|
||||||
:param _input: Bag
|
:param _input: Bag
|
||||||
:param _output: mixed
|
:param _output: mixed
|
||||||
:return: Bag
|
:return: Bag
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if _output is NOT_MODIFIED:
|
tokens, _output = split_token(_output)
|
||||||
if self._output_type is None:
|
|
||||||
return _input
|
|
||||||
else:
|
|
||||||
return self._output_type(*_input)
|
|
||||||
|
|
||||||
return ensure_tuple(_output, cls=(self.output_type or tuple))
|
if NOT_MODIFIED in tokens:
|
||||||
|
return ensure_tuple(_input, cls=(self.output_type or tuple))
|
||||||
|
|
||||||
|
if INHERIT in tokens:
|
||||||
|
if self._output_type is None:
|
||||||
|
self._output_type = concat_types(self._input_type, self._input_length, self._output_type, len(_output))
|
||||||
|
_output = _input + ensure_tuple(_output)
|
||||||
|
|
||||||
|
return ensure_tuple(_output, cls=(self._output_type or tuple))
|
||||||
|
|
||||||
def _send(self, value, _control=False):
|
def _send(self, value, _control=False):
|
||||||
"""
|
"""
|
||||||
@ -330,26 +334,44 @@ class NodeExecutionContext(BaseContext, WithStatistics):
|
|||||||
|
|
||||||
|
|
||||||
def isflag(param):
|
def isflag(param):
|
||||||
return isinstance(param, Token) and param in (NOT_MODIFIED, )
|
return isinstance(param, Flag)
|
||||||
|
|
||||||
|
|
||||||
def split_tokens(output):
|
def split_token(output):
|
||||||
"""
|
"""
|
||||||
Split an output into token tuple, real output tuple.
|
Split an output into token tuple, real output tuple.
|
||||||
|
|
||||||
:param output:
|
:param output:
|
||||||
:return: tuple, tuple
|
:return: tuple, tuple
|
||||||
"""
|
"""
|
||||||
if isinstance(output, Token):
|
|
||||||
# just a flag
|
|
||||||
return (output, ), ()
|
|
||||||
|
|
||||||
if not istuple(output):
|
output = ensure_tuple(output)
|
||||||
# no flag
|
|
||||||
return (), (output, )
|
|
||||||
|
|
||||||
i = 0
|
flags, i, len_output, data_allowed = set(), 0, len(output), True
|
||||||
while isflag(output[i]):
|
while i < len_output and isflag(output[i]):
|
||||||
|
if output[i].must_be_first and i:
|
||||||
|
raise ValueError('{} flag must be first.'.format(output[i]))
|
||||||
|
if i and output[i - 1].must_be_last:
|
||||||
|
raise ValueError('{} flag must be last.'.format(output[i - 1]))
|
||||||
|
if output[i] in flags:
|
||||||
|
raise ValueError('Duplicate flag {}.'.format(output[i]))
|
||||||
|
flags.add(output[i])
|
||||||
|
data_allowed &= output[i].allows_data
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
return output[:i], output[i:]
|
output = output[i:]
|
||||||
|
if not data_allowed and len(output):
|
||||||
|
raise ValueError('Output data provided after a flag that does not allow data.')
|
||||||
|
return flags, output
|
||||||
|
|
||||||
|
|
||||||
|
def concat_types(t1, l1, t2, l2):
|
||||||
|
t1, t2 = t1 or tuple, t2 or tuple
|
||||||
|
|
||||||
|
if t1 == t2 == tuple:
|
||||||
|
return tuple
|
||||||
|
|
||||||
|
f1 = t1._fields if hasattr(t1, '_fields') else tuple(range(l1))
|
||||||
|
f2 = t2._fields if hasattr(t2, '_fields') else tuple(range(l2))
|
||||||
|
|
||||||
|
return BagType('Inherited', f1 + f2)
|
||||||
|
|||||||
@ -1,11 +1,7 @@
|
|||||||
import functools
|
import functools
|
||||||
|
import html
|
||||||
import itertools
|
import itertools
|
||||||
import operator
|
|
||||||
import pprint
|
import pprint
|
||||||
from functools import reduce
|
|
||||||
|
|
||||||
from bonobo.util import ensure_tuple
|
|
||||||
from mondrian import term
|
|
||||||
|
|
||||||
from bonobo import settings
|
from bonobo import settings
|
||||||
from bonobo.config import Configurable, Option, Method, use_raw_input, use_context, use_no_input
|
from bonobo.config import Configurable, Option, Method, use_raw_input, use_context, use_no_input
|
||||||
@ -14,6 +10,7 @@ from bonobo.config.processors import ContextProcessor, use_context_processor
|
|||||||
from bonobo.constants import NOT_MODIFIED
|
from bonobo.constants import NOT_MODIFIED
|
||||||
from bonobo.util.objects import ValueHolder
|
from bonobo.util.objects import ValueHolder
|
||||||
from bonobo.util.term import CLEAR_EOL
|
from bonobo.util.term import CLEAR_EOL
|
||||||
|
from mondrian import term
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'FixedWindow',
|
'FixedWindow',
|
||||||
@ -94,29 +91,41 @@ class PrettyPrinter(Configurable):
|
|||||||
|
|
||||||
@ContextProcessor
|
@ContextProcessor
|
||||||
def context(self, context):
|
def context(self, context):
|
||||||
|
context.setdefault('_jupyter_html', None)
|
||||||
yield context
|
yield context
|
||||||
|
if context._jupyter_html is not None:
|
||||||
|
from IPython.display import display, HTML
|
||||||
|
display(HTML('\n'.join(['<table>'] + context._jupyter_html + ['</table>'])))
|
||||||
|
|
||||||
def __call__(self, context, *args, **kwargs):
|
def __call__(self, context, *args, **kwargs):
|
||||||
quiet = settings.QUIET.get()
|
if not settings.QUIET:
|
||||||
formater = self._format_quiet if quiet else self._format_console
|
if term.isjupyter:
|
||||||
|
self.print_jupyter(context, *args, **kwargs)
|
||||||
if not quiet:
|
return NOT_MODIFIED
|
||||||
print('\u250e' + '\u2500' * (self.max_width - 1))
|
if term.istty:
|
||||||
|
self.print_console(context, *args, **kwargs)
|
||||||
for index, (key, value) in enumerate(itertools.chain(enumerate(args), kwargs.items())):
|
return NOT_MODIFIED
|
||||||
if self.filter(index, key, value):
|
|
||||||
print(formater(index, key, value, fields=context.get_input_fields()))
|
|
||||||
|
|
||||||
if not quiet:
|
|
||||||
print('\u2516' + '\u2500' * (self.max_width - 1))
|
|
||||||
|
|
||||||
|
self.print_quiet(context, *args, **kwargs)
|
||||||
return NOT_MODIFIED
|
return NOT_MODIFIED
|
||||||
|
|
||||||
def _format_quiet(self, index, key, value, *, fields=None):
|
def print_quiet(self, context, *args, **kwargs):
|
||||||
|
for index, (key, value) in enumerate(itertools.chain(enumerate(args), kwargs.items())):
|
||||||
|
if self.filter(index, key, value):
|
||||||
|
print(self.format_quiet(index, key, value, fields=context.get_input_fields()))
|
||||||
|
|
||||||
|
def format_quiet(self, index, key, value, *, fields=None):
|
||||||
# XXX should we implement argnames here ?
|
# XXX should we implement argnames here ?
|
||||||
return ' '.join(((' ' if index else '-'), str(key), ':', str(value).strip()))
|
return ' '.join(((' ' if index else '-'), str(key), ':', str(value).strip()))
|
||||||
|
|
||||||
def _format_console(self, index, key, value, *, fields=None):
|
def print_console(self, context, *args, **kwargs):
|
||||||
|
print('\u250e' + '\u2500' * (self.max_width - 1))
|
||||||
|
for index, (key, value) in enumerate(itertools.chain(enumerate(args), kwargs.items())):
|
||||||
|
if self.filter(index, key, value):
|
||||||
|
print(self.format_console(index, key, value, fields=context.get_input_fields()))
|
||||||
|
print('\u2516' + '\u2500' * (self.max_width - 1))
|
||||||
|
|
||||||
|
def format_console(self, index, key, value, *, fields=None):
|
||||||
fields = fields or []
|
fields = fields or []
|
||||||
if not isinstance(key, str):
|
if not isinstance(key, str):
|
||||||
if len(fields) >= key and str(key) != str(fields[key]):
|
if len(fields) >= key and str(key) != str(fields[key]):
|
||||||
@ -136,6 +145,21 @@ class PrettyPrinter(Configurable):
|
|||||||
).strip()
|
).strip()
|
||||||
return '{}{}{}'.format(prefix, repr_of_value.replace('\n', CLEAR_EOL + '\n'), CLEAR_EOL)
|
return '{}{}{}'.format(prefix, repr_of_value.replace('\n', CLEAR_EOL + '\n'), CLEAR_EOL)
|
||||||
|
|
||||||
|
def print_jupyter(self, context, *args):
|
||||||
|
if not context._jupyter_html:
|
||||||
|
context._jupyter_html = [
|
||||||
|
'<thead><tr>',
|
||||||
|
*map('<th>{}</th>'.format, map(html.escape, map(str,
|
||||||
|
context.get_input_fields() or range(len(args))))),
|
||||||
|
'</tr></thead>',
|
||||||
|
]
|
||||||
|
|
||||||
|
context._jupyter_html += [
|
||||||
|
'<tr>',
|
||||||
|
*map('<td>{}</td>'.format, map(html.escape, map(repr, args))),
|
||||||
|
'</tr>',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@use_no_input
|
@use_no_input
|
||||||
def noop(*args, **kwargs):
|
def noop(*args, **kwargs):
|
||||||
|
|||||||
@ -46,6 +46,9 @@ class Setting:
|
|||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.get() == other
|
return self.get() == other
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return bool(self.get())
|
||||||
|
|
||||||
def set(self, value):
|
def set(self, value):
|
||||||
value = self.formatter(value) if self.formatter else value
|
value = self.formatter(value) if self.formatter else value
|
||||||
if self.validator and not self.validator(value):
|
if self.validator and not self.validator(value):
|
||||||
|
|||||||
@ -192,6 +192,4 @@ rst_epilog = """
|
|||||||
|
|
||||||
.. |longversion| replace:: v.{version}
|
.. |longversion| replace:: v.{version}
|
||||||
|
|
||||||
""".format(
|
""".format(version=version, )
|
||||||
version=version,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -3,8 +3,8 @@ from unittest.mock import MagicMock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from bonobo import Graph
|
from bonobo import Graph
|
||||||
from bonobo.constants import EMPTY
|
from bonobo.constants import EMPTY, NOT_MODIFIED, INHERIT
|
||||||
from bonobo.execution.contexts.node import NodeExecutionContext
|
from bonobo.execution.contexts.node import NodeExecutionContext, split_token
|
||||||
from bonobo.execution.strategies import NaiveStrategy
|
from bonobo.execution.strategies import NaiveStrategy
|
||||||
from bonobo.util.testing import BufferingNodeExecutionContext, BufferingGraphExecutionContext
|
from bonobo.util.testing import BufferingNodeExecutionContext, BufferingGraphExecutionContext
|
||||||
|
|
||||||
@ -224,3 +224,35 @@ def test_node_lifecycle_with_kill():
|
|||||||
|
|
||||||
ctx.stop()
|
ctx.stop()
|
||||||
assert all((ctx.started, ctx.killed, ctx.stopped)) and not ctx.alive
|
assert all((ctx.started, ctx.killed, ctx.stopped)) and not ctx.alive
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_token():
|
||||||
|
assert split_token(('foo', 'bar')) == (set(), ('foo', 'bar'))
|
||||||
|
assert split_token(()) == (set(), ())
|
||||||
|
assert split_token('') == (set(), ('', ))
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_token_duplicate():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
split_token((NOT_MODIFIED, NOT_MODIFIED))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
split_token((INHERIT, INHERIT))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
split_token((INHERIT, NOT_MODIFIED, INHERIT))
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_token_not_modified():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
split_token((NOT_MODIFIED, 'foo', 'bar'))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
split_token((NOT_MODIFIED, INHERIT))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
split_token((INHERIT, NOT_MODIFIED))
|
||||||
|
assert split_token(NOT_MODIFIED) == ({NOT_MODIFIED}, ())
|
||||||
|
assert split_token((NOT_MODIFIED, )) == ({NOT_MODIFIED}, ())
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_token_inherit():
|
||||||
|
assert split_token(INHERIT) == ({INHERIT}, ())
|
||||||
|
assert split_token((INHERIT, )) == ({INHERIT}, ())
|
||||||
|
assert split_token((INHERIT, 'foo', 'bar')) == ({INHERIT}, ('foo', 'bar'))
|
||||||
|
|||||||
27
tests/features/test_inherit.py
Normal file
27
tests/features/test_inherit.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from bonobo.constants import INHERIT
|
||||||
|
from bonobo.util.testing import BufferingNodeExecutionContext
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
('Hello', ),
|
||||||
|
('Goodbye', ),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def append(*args):
|
||||||
|
return INHERIT, '!'
|
||||||
|
|
||||||
|
|
||||||
|
def test_inherit():
|
||||||
|
with BufferingNodeExecutionContext(append) as context:
|
||||||
|
context.write_sync(*messages)
|
||||||
|
|
||||||
|
assert context.get_buffer() == list(map(lambda x: x + ('!', ), messages))
|
||||||
|
|
||||||
|
|
||||||
|
def test_inherit_bag_tuple():
|
||||||
|
with BufferingNodeExecutionContext(append) as context:
|
||||||
|
context.set_input_fields(['message'])
|
||||||
|
context.write_sync(*messages)
|
||||||
|
|
||||||
|
assert context.get_output_fields() == ('message', '0')
|
||||||
|
assert context.get_buffer() == list(map(lambda x: x + ('!', ), messages))
|
||||||
Reference in New Issue
Block a user