From d8c0dfe11a514dc24b43944ce73c751b1a174cce Mon Sep 17 00:00:00 2001 From: Romain Dorgueil Date: Tue, 28 Nov 2017 21:58:01 +0100 Subject: [PATCH] Inheritance of bags and better jupyter output for pretty printer. --- bonobo/constants.py | 15 +++++-- bonobo/examples/__init__.py | 4 +- bonobo/examples/datasets/__main__.py | 9 +++- bonobo/execution/contexts/__init__.py | 9 ++++ bonobo/execution/contexts/node.py | 62 ++++++++++++++++++--------- bonobo/nodes/basics.py | 62 +++++++++++++++++++-------- bonobo/settings.py | 3 ++ docs/conf.py | 4 +- tests/execution/contexts/test_node.py | 36 +++++++++++++++- tests/features/test_inherit.py | 27 ++++++++++++ 10 files changed, 180 insertions(+), 51 deletions(-) create mode 100644 tests/features/test_inherit.py diff --git a/bonobo/constants.py b/bonobo/constants.py index fceb8f9..b1a199c 100644 --- a/bonobo/constants.py +++ b/bonobo/constants.py @@ -11,9 +11,18 @@ class Token: BEGIN = Token('Begin') END = Token('End') -INHERIT_INPUT = Token('InheritInput') -LOOPBACK = Token('Loopback') -NOT_MODIFIED = Token('NotModified') + +class Flag(Token): + must_be_first = False + must_be_last = False + allows_data = True + + +INHERIT = Flag('Inherit') +NOT_MODIFIED = Flag('NotModified') +NOT_MODIFIED.must_be_first = True +NOT_MODIFIED.must_be_last = True +NOT_MODIFIED.allows_data = False EMPTY = tuple() diff --git a/bonobo/examples/__init__.py b/bonobo/examples/__init__.py index e1815f8..ec68fc5 100644 --- a/bonobo/examples/__init__.py +++ b/bonobo/examples/__init__.py @@ -27,6 +27,6 @@ def get_graph_options(options): _print = options.pop('print', False) return { - '_limit': (bonobo.Limit(_limit),) if _limit else (), - '_print': (bonobo.PrettyPrinter(),) if _print else (), + '_limit': (bonobo.Limit(_limit), ) if _limit else (), + '_print': (bonobo.PrettyPrinter(), ) if _print else (), } diff --git a/bonobo/examples/datasets/__main__.py b/bonobo/examples/datasets/__main__.py index 768ac5c..55a7fde 100644 --- a/bonobo/examples/datasets/__main__.py +++ b/bonobo/examples/datasets/__main__.py @@ -66,11 +66,16 @@ def get_services(): if __name__ == '__main__': parser = examples.get_argument_parser() - parser.add_argument('--target', '-t', choices=graphs.keys(), nargs='+') + parser.add_argument( + '--target', '-t', choices=graphs.keys(), nargs='+' + ) with bonobo.parse_args(parser) as options: graph_options = examples.get_graph_options(options) - graph_names = list(options['target'] if options['target'] else sorted(graphs.keys())) + graph_names = list( + options['target'] + if options['target'] else sorted(graphs.keys()) + ) graph = bonobo.Graph() for name in graph_names: diff --git a/bonobo/execution/contexts/__init__.py b/bonobo/execution/contexts/__init__.py index e69de29..4c462c5 100644 --- a/bonobo/execution/contexts/__init__.py +++ b/bonobo/execution/contexts/__init__.py @@ -0,0 +1,9 @@ +from bonobo.execution.contexts.graph import GraphExecutionContext +from bonobo.execution.contexts.node import NodeExecutionContext +from bonobo.execution.contexts.plugin import PluginExecutionContext + +__all__ = [ + 'GraphExecutionContext', + 'NodeExecutionContext', + 'PluginExecutionContext', +] diff --git a/bonobo/execution/contexts/node.py b/bonobo/execution/contexts/node.py index 194cf36..316d9b8 100644 --- a/bonobo/execution/contexts/node.py +++ b/bonobo/execution/contexts/node.py @@ -7,11 +7,11 @@ from types import GeneratorType from bonobo.config import create_container from bonobo.config.processors import ContextCurrifier -from bonobo.constants import NOT_MODIFIED, BEGIN, END, TICK_PERIOD, Token +from bonobo.constants import NOT_MODIFIED, BEGIN, END, TICK_PERIOD, Token, Flag, INHERIT from bonobo.errors import InactiveReadableError, UnrecoverableError, UnrecoverableTypeError from bonobo.execution.contexts.base import BaseContext from bonobo.structs.inputs import Input -from bonobo.util import get_name, istuple, isconfigurabletype, ensure_tuple +from bonobo.util import get_name, isconfigurabletype, ensure_tuple from bonobo.util.bags import BagType from bonobo.util.statistics import WithStatistics @@ -292,20 +292,24 @@ class NodeExecutionContext(BaseContext, WithStatistics): def _cast(self, _input, _output): """ - Transforms a pair of input/output into what is the real output. + Transforms a pair of input/output into the real slim output. :param _input: Bag :param _output: mixed :return: Bag """ - if _output is NOT_MODIFIED: - if self._output_type is None: - return _input - else: - return self._output_type(*_input) + tokens, _output = split_token(_output) - return ensure_tuple(_output, cls=(self.output_type or tuple)) + if NOT_MODIFIED in tokens: + return ensure_tuple(_input, cls=(self.output_type or tuple)) + + if INHERIT in tokens: + if self._output_type is None: + self._output_type = concat_types(self._input_type, self._input_length, self._output_type, len(_output)) + _output = _input + ensure_tuple(_output) + + return ensure_tuple(_output, cls=(self._output_type or tuple)) def _send(self, value, _control=False): """ @@ -330,26 +334,44 @@ class NodeExecutionContext(BaseContext, WithStatistics): def isflag(param): - return isinstance(param, Token) and param in (NOT_MODIFIED, ) + return isinstance(param, Flag) -def split_tokens(output): +def split_token(output): """ Split an output into token tuple, real output tuple. :param output: :return: tuple, tuple """ - if isinstance(output, Token): - # just a flag - return (output, ), () - if not istuple(output): - # no flag - return (), (output, ) + output = ensure_tuple(output) - i = 0 - while isflag(output[i]): + flags, i, len_output, data_allowed = set(), 0, len(output), True + while i < len_output and isflag(output[i]): + if output[i].must_be_first and i: + raise ValueError('{} flag must be first.'.format(output[i])) + if i and output[i - 1].must_be_last: + raise ValueError('{} flag must be last.'.format(output[i - 1])) + if output[i] in flags: + raise ValueError('Duplicate flag {}.'.format(output[i])) + flags.add(output[i]) + data_allowed &= output[i].allows_data i += 1 - return output[:i], output[i:] + output = output[i:] + if not data_allowed and len(output): + raise ValueError('Output data provided after a flag that does not allow data.') + return flags, output + + +def concat_types(t1, l1, t2, l2): + t1, t2 = t1 or tuple, t2 or tuple + + if t1 == t2 == tuple: + return tuple + + f1 = t1._fields if hasattr(t1, '_fields') else tuple(range(l1)) + f2 = t2._fields if hasattr(t2, '_fields') else tuple(range(l2)) + + return BagType('Inherited', f1 + f2) diff --git a/bonobo/nodes/basics.py b/bonobo/nodes/basics.py index 9710ef7..ecf6827 100644 --- a/bonobo/nodes/basics.py +++ b/bonobo/nodes/basics.py @@ -1,11 +1,7 @@ import functools +import html import itertools -import operator import pprint -from functools import reduce - -from bonobo.util import ensure_tuple -from mondrian import term from bonobo import settings from bonobo.config import Configurable, Option, Method, use_raw_input, use_context, use_no_input @@ -14,6 +10,7 @@ from bonobo.config.processors import ContextProcessor, use_context_processor from bonobo.constants import NOT_MODIFIED from bonobo.util.objects import ValueHolder from bonobo.util.term import CLEAR_EOL +from mondrian import term __all__ = [ 'FixedWindow', @@ -94,29 +91,41 @@ class PrettyPrinter(Configurable): @ContextProcessor def context(self, context): + context.setdefault('_jupyter_html', None) yield context + if context._jupyter_html is not None: + from IPython.display import display, HTML + display(HTML('\n'.join([''] + context._jupyter_html + ['
']))) def __call__(self, context, *args, **kwargs): - quiet = settings.QUIET.get() - formater = self._format_quiet if quiet else self._format_console - - if not quiet: - print('\u250e' + '\u2500' * (self.max_width - 1)) - - for index, (key, value) in enumerate(itertools.chain(enumerate(args), kwargs.items())): - if self.filter(index, key, value): - print(formater(index, key, value, fields=context.get_input_fields())) - - if not quiet: - print('\u2516' + '\u2500' * (self.max_width - 1)) + if not settings.QUIET: + if term.isjupyter: + self.print_jupyter(context, *args, **kwargs) + return NOT_MODIFIED + if term.istty: + self.print_console(context, *args, **kwargs) + return NOT_MODIFIED + self.print_quiet(context, *args, **kwargs) return NOT_MODIFIED - def _format_quiet(self, index, key, value, *, fields=None): + def print_quiet(self, context, *args, **kwargs): + for index, (key, value) in enumerate(itertools.chain(enumerate(args), kwargs.items())): + if self.filter(index, key, value): + print(self.format_quiet(index, key, value, fields=context.get_input_fields())) + + def format_quiet(self, index, key, value, *, fields=None): # XXX should we implement argnames here ? return ' '.join(((' ' if index else '-'), str(key), ':', str(value).strip())) - def _format_console(self, index, key, value, *, fields=None): + def print_console(self, context, *args, **kwargs): + print('\u250e' + '\u2500' * (self.max_width - 1)) + for index, (key, value) in enumerate(itertools.chain(enumerate(args), kwargs.items())): + if self.filter(index, key, value): + print(self.format_console(index, key, value, fields=context.get_input_fields())) + print('\u2516' + '\u2500' * (self.max_width - 1)) + + def format_console(self, index, key, value, *, fields=None): fields = fields or [] if not isinstance(key, str): if len(fields) >= key and str(key) != str(fields[key]): @@ -136,6 +145,21 @@ class PrettyPrinter(Configurable): ).strip() return '{}{}{}'.format(prefix, repr_of_value.replace('\n', CLEAR_EOL + '\n'), CLEAR_EOL) + def print_jupyter(self, context, *args): + if not context._jupyter_html: + context._jupyter_html = [ + '', + *map('{}'.format, map(html.escape, map(str, + context.get_input_fields() or range(len(args))))), + '', + ] + + context._jupyter_html += [ + '', + *map('{}'.format, map(html.escape, map(repr, args))), + '', + ] + @use_no_input def noop(*args, **kwargs): diff --git a/bonobo/settings.py b/bonobo/settings.py index fdc4412..799ba3d 100644 --- a/bonobo/settings.py +++ b/bonobo/settings.py @@ -46,6 +46,9 @@ class Setting: def __eq__(self, other): return self.get() == other + def __bool__(self): + return bool(self.get()) + def set(self, value): value = self.formatter(value) if self.formatter else value if self.validator and not self.validator(value): diff --git a/docs/conf.py b/docs/conf.py index 46a9b55..53ed816 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -192,6 +192,4 @@ rst_epilog = """ .. |longversion| replace:: v.{version} -""".format( - version=version, -) +""".format(version=version, ) diff --git a/tests/execution/contexts/test_node.py b/tests/execution/contexts/test_node.py index 0ebae6d..c9c8f1f 100644 --- a/tests/execution/contexts/test_node.py +++ b/tests/execution/contexts/test_node.py @@ -3,8 +3,8 @@ from unittest.mock import MagicMock import pytest from bonobo import Graph -from bonobo.constants import EMPTY -from bonobo.execution.contexts.node import NodeExecutionContext +from bonobo.constants import EMPTY, NOT_MODIFIED, INHERIT +from bonobo.execution.contexts.node import NodeExecutionContext, split_token from bonobo.execution.strategies import NaiveStrategy from bonobo.util.testing import BufferingNodeExecutionContext, BufferingGraphExecutionContext @@ -224,3 +224,35 @@ def test_node_lifecycle_with_kill(): ctx.stop() assert all((ctx.started, ctx.killed, ctx.stopped)) and not ctx.alive + + +def test_split_token(): + assert split_token(('foo', 'bar')) == (set(), ('foo', 'bar')) + assert split_token(()) == (set(), ()) + assert split_token('') == (set(), ('', )) + + +def test_split_token_duplicate(): + with pytest.raises(ValueError): + split_token((NOT_MODIFIED, NOT_MODIFIED)) + with pytest.raises(ValueError): + split_token((INHERIT, INHERIT)) + with pytest.raises(ValueError): + split_token((INHERIT, NOT_MODIFIED, INHERIT)) + + +def test_split_token_not_modified(): + with pytest.raises(ValueError): + split_token((NOT_MODIFIED, 'foo', 'bar')) + with pytest.raises(ValueError): + split_token((NOT_MODIFIED, INHERIT)) + with pytest.raises(ValueError): + split_token((INHERIT, NOT_MODIFIED)) + assert split_token(NOT_MODIFIED) == ({NOT_MODIFIED}, ()) + assert split_token((NOT_MODIFIED, )) == ({NOT_MODIFIED}, ()) + + +def test_split_token_inherit(): + assert split_token(INHERIT) == ({INHERIT}, ()) + assert split_token((INHERIT, )) == ({INHERIT}, ()) + assert split_token((INHERIT, 'foo', 'bar')) == ({INHERIT}, ('foo', 'bar')) diff --git a/tests/features/test_inherit.py b/tests/features/test_inherit.py new file mode 100644 index 0000000..92b943b --- /dev/null +++ b/tests/features/test_inherit.py @@ -0,0 +1,27 @@ +from bonobo.constants import INHERIT +from bonobo.util.testing import BufferingNodeExecutionContext + +messages = [ + ('Hello', ), + ('Goodbye', ), +] + + +def append(*args): + return INHERIT, '!' + + +def test_inherit(): + with BufferingNodeExecutionContext(append) as context: + context.write_sync(*messages) + + assert context.get_buffer() == list(map(lambda x: x + ('!', ), messages)) + + +def test_inherit_bag_tuple(): + with BufferingNodeExecutionContext(append) as context: + context.set_input_fields(['message']) + context.write_sync(*messages) + + assert context.get_output_fields() == ('message', '0') + assert context.get_buffer() == list(map(lambda x: x + ('!', ), messages))