Inheritance of bags and better jupyter output for pretty printer.

This commit is contained in:
Romain Dorgueil
2017-11-28 21:58:01 +01:00
parent c7ff06a742
commit d8c0dfe11a
10 changed files with 180 additions and 51 deletions

View File

@ -11,9 +11,18 @@ class Token:
BEGIN = Token('Begin') BEGIN = Token('Begin')
END = Token('End') END = Token('End')
INHERIT_INPUT = Token('InheritInput')
LOOPBACK = Token('Loopback') class Flag(Token):
NOT_MODIFIED = Token('NotModified') must_be_first = False
must_be_last = False
allows_data = True
INHERIT = Flag('Inherit')
NOT_MODIFIED = Flag('NotModified')
NOT_MODIFIED.must_be_first = True
NOT_MODIFIED.must_be_last = True
NOT_MODIFIED.allows_data = False
EMPTY = tuple() EMPTY = tuple()

View File

@ -66,11 +66,16 @@ def get_services():
if __name__ == '__main__': if __name__ == '__main__':
parser = examples.get_argument_parser() parser = examples.get_argument_parser()
parser.add_argument('--target', '-t', choices=graphs.keys(), nargs='+') parser.add_argument(
'--target', '-t', choices=graphs.keys(), nargs='+'
)
with bonobo.parse_args(parser) as options: with bonobo.parse_args(parser) as options:
graph_options = examples.get_graph_options(options) graph_options = examples.get_graph_options(options)
graph_names = list(options['target'] if options['target'] else sorted(graphs.keys())) graph_names = list(
options['target']
if options['target'] else sorted(graphs.keys())
)
graph = bonobo.Graph() graph = bonobo.Graph()
for name in graph_names: for name in graph_names:

View File

@ -0,0 +1,9 @@
from bonobo.execution.contexts.graph import GraphExecutionContext
from bonobo.execution.contexts.node import NodeExecutionContext
from bonobo.execution.contexts.plugin import PluginExecutionContext
__all__ = [
'GraphExecutionContext',
'NodeExecutionContext',
'PluginExecutionContext',
]

View File

@ -7,11 +7,11 @@ from types import GeneratorType
from bonobo.config import create_container from bonobo.config import create_container
from bonobo.config.processors import ContextCurrifier from bonobo.config.processors import ContextCurrifier
from bonobo.constants import NOT_MODIFIED, BEGIN, END, TICK_PERIOD, Token from bonobo.constants import NOT_MODIFIED, BEGIN, END, TICK_PERIOD, Token, Flag, INHERIT
from bonobo.errors import InactiveReadableError, UnrecoverableError, UnrecoverableTypeError from bonobo.errors import InactiveReadableError, UnrecoverableError, UnrecoverableTypeError
from bonobo.execution.contexts.base import BaseContext from bonobo.execution.contexts.base import BaseContext
from bonobo.structs.inputs import Input from bonobo.structs.inputs import Input
from bonobo.util import get_name, istuple, isconfigurabletype, ensure_tuple from bonobo.util import get_name, isconfigurabletype, ensure_tuple
from bonobo.util.bags import BagType from bonobo.util.bags import BagType
from bonobo.util.statistics import WithStatistics from bonobo.util.statistics import WithStatistics
@ -292,20 +292,24 @@ class NodeExecutionContext(BaseContext, WithStatistics):
def _cast(self, _input, _output): def _cast(self, _input, _output):
""" """
Transforms a pair of input/output into what is the real output. Transforms a pair of input/output into the real slim output.
:param _input: Bag :param _input: Bag
:param _output: mixed :param _output: mixed
:return: Bag :return: Bag
""" """
if _output is NOT_MODIFIED: tokens, _output = split_token(_output)
if self._output_type is None:
return _input
else:
return self._output_type(*_input)
return ensure_tuple(_output, cls=(self.output_type or tuple)) if NOT_MODIFIED in tokens:
return ensure_tuple(_input, cls=(self.output_type or tuple))
if INHERIT in tokens:
if self._output_type is None:
self._output_type = concat_types(self._input_type, self._input_length, self._output_type, len(_output))
_output = _input + ensure_tuple(_output)
return ensure_tuple(_output, cls=(self._output_type or tuple))
def _send(self, value, _control=False): def _send(self, value, _control=False):
""" """
@ -330,26 +334,44 @@ class NodeExecutionContext(BaseContext, WithStatistics):
def isflag(param): def isflag(param):
return isinstance(param, Token) and param in (NOT_MODIFIED, ) return isinstance(param, Flag)
def split_tokens(output): def split_token(output):
""" """
Split an output into token tuple, real output tuple. Split an output into token tuple, real output tuple.
:param output: :param output:
:return: tuple, tuple :return: tuple, tuple
""" """
if isinstance(output, Token):
# just a flag
return (output, ), ()
if not istuple(output): output = ensure_tuple(output)
# no flag
return (), (output, )
i = 0 flags, i, len_output, data_allowed = set(), 0, len(output), True
while isflag(output[i]): while i < len_output and isflag(output[i]):
if output[i].must_be_first and i:
raise ValueError('{} flag must be first.'.format(output[i]))
if i and output[i - 1].must_be_last:
raise ValueError('{} flag must be last.'.format(output[i - 1]))
if output[i] in flags:
raise ValueError('Duplicate flag {}.'.format(output[i]))
flags.add(output[i])
data_allowed &= output[i].allows_data
i += 1 i += 1
return output[:i], output[i:] output = output[i:]
if not data_allowed and len(output):
raise ValueError('Output data provided after a flag that does not allow data.')
return flags, output
def concat_types(t1, l1, t2, l2):
t1, t2 = t1 or tuple, t2 or tuple
if t1 == t2 == tuple:
return tuple
f1 = t1._fields if hasattr(t1, '_fields') else tuple(range(l1))
f2 = t2._fields if hasattr(t2, '_fields') else tuple(range(l2))
return BagType('Inherited', f1 + f2)

View File

@ -1,11 +1,7 @@
import functools import functools
import html
import itertools import itertools
import operator
import pprint import pprint
from functools import reduce
from bonobo.util import ensure_tuple
from mondrian import term
from bonobo import settings from bonobo import settings
from bonobo.config import Configurable, Option, Method, use_raw_input, use_context, use_no_input from bonobo.config import Configurable, Option, Method, use_raw_input, use_context, use_no_input
@ -14,6 +10,7 @@ from bonobo.config.processors import ContextProcessor, use_context_processor
from bonobo.constants import NOT_MODIFIED from bonobo.constants import NOT_MODIFIED
from bonobo.util.objects import ValueHolder from bonobo.util.objects import ValueHolder
from bonobo.util.term import CLEAR_EOL from bonobo.util.term import CLEAR_EOL
from mondrian import term
__all__ = [ __all__ = [
'FixedWindow', 'FixedWindow',
@ -94,29 +91,41 @@ class PrettyPrinter(Configurable):
@ContextProcessor @ContextProcessor
def context(self, context): def context(self, context):
context.setdefault('_jupyter_html', None)
yield context yield context
if context._jupyter_html is not None:
from IPython.display import display, HTML
display(HTML('\n'.join(['<table>'] + context._jupyter_html + ['</table>'])))
def __call__(self, context, *args, **kwargs): def __call__(self, context, *args, **kwargs):
quiet = settings.QUIET.get() if not settings.QUIET:
formater = self._format_quiet if quiet else self._format_console if term.isjupyter:
self.print_jupyter(context, *args, **kwargs)
if not quiet: return NOT_MODIFIED
print('\u250e' + '\u2500' * (self.max_width - 1)) if term.istty:
self.print_console(context, *args, **kwargs)
for index, (key, value) in enumerate(itertools.chain(enumerate(args), kwargs.items())):
if self.filter(index, key, value):
print(formater(index, key, value, fields=context.get_input_fields()))
if not quiet:
print('\u2516' + '\u2500' * (self.max_width - 1))
return NOT_MODIFIED return NOT_MODIFIED
def _format_quiet(self, index, key, value, *, fields=None): self.print_quiet(context, *args, **kwargs)
return NOT_MODIFIED
def print_quiet(self, context, *args, **kwargs):
for index, (key, value) in enumerate(itertools.chain(enumerate(args), kwargs.items())):
if self.filter(index, key, value):
print(self.format_quiet(index, key, value, fields=context.get_input_fields()))
def format_quiet(self, index, key, value, *, fields=None):
# XXX should we implement argnames here ? # XXX should we implement argnames here ?
return ' '.join(((' ' if index else '-'), str(key), ':', str(value).strip())) return ' '.join(((' ' if index else '-'), str(key), ':', str(value).strip()))
def _format_console(self, index, key, value, *, fields=None): def print_console(self, context, *args, **kwargs):
print('\u250e' + '\u2500' * (self.max_width - 1))
for index, (key, value) in enumerate(itertools.chain(enumerate(args), kwargs.items())):
if self.filter(index, key, value):
print(self.format_console(index, key, value, fields=context.get_input_fields()))
print('\u2516' + '\u2500' * (self.max_width - 1))
def format_console(self, index, key, value, *, fields=None):
fields = fields or [] fields = fields or []
if not isinstance(key, str): if not isinstance(key, str):
if len(fields) >= key and str(key) != str(fields[key]): if len(fields) >= key and str(key) != str(fields[key]):
@ -136,6 +145,21 @@ class PrettyPrinter(Configurable):
).strip() ).strip()
return '{}{}{}'.format(prefix, repr_of_value.replace('\n', CLEAR_EOL + '\n'), CLEAR_EOL) return '{}{}{}'.format(prefix, repr_of_value.replace('\n', CLEAR_EOL + '\n'), CLEAR_EOL)
def print_jupyter(self, context, *args):
if not context._jupyter_html:
context._jupyter_html = [
'<thead><tr>',
*map('<th>{}</th>'.format, map(html.escape, map(str,
context.get_input_fields() or range(len(args))))),
'</tr></thead>',
]
context._jupyter_html += [
'<tr>',
*map('<td>{}</td>'.format, map(html.escape, map(repr, args))),
'</tr>',
]
@use_no_input @use_no_input
def noop(*args, **kwargs): def noop(*args, **kwargs):

View File

@ -46,6 +46,9 @@ class Setting:
def __eq__(self, other): def __eq__(self, other):
return self.get() == other return self.get() == other
def __bool__(self):
return bool(self.get())
def set(self, value): def set(self, value):
value = self.formatter(value) if self.formatter else value value = self.formatter(value) if self.formatter else value
if self.validator and not self.validator(value): if self.validator and not self.validator(value):

View File

@ -192,6 +192,4 @@ rst_epilog = """
.. |longversion| replace:: v.{version} .. |longversion| replace:: v.{version}
""".format( """.format(version=version, )
version=version,
)

View File

@ -3,8 +3,8 @@ from unittest.mock import MagicMock
import pytest import pytest
from bonobo import Graph from bonobo import Graph
from bonobo.constants import EMPTY from bonobo.constants import EMPTY, NOT_MODIFIED, INHERIT
from bonobo.execution.contexts.node import NodeExecutionContext from bonobo.execution.contexts.node import NodeExecutionContext, split_token
from bonobo.execution.strategies import NaiveStrategy from bonobo.execution.strategies import NaiveStrategy
from bonobo.util.testing import BufferingNodeExecutionContext, BufferingGraphExecutionContext from bonobo.util.testing import BufferingNodeExecutionContext, BufferingGraphExecutionContext
@ -224,3 +224,35 @@ def test_node_lifecycle_with_kill():
ctx.stop() ctx.stop()
assert all((ctx.started, ctx.killed, ctx.stopped)) and not ctx.alive assert all((ctx.started, ctx.killed, ctx.stopped)) and not ctx.alive
def test_split_token():
assert split_token(('foo', 'bar')) == (set(), ('foo', 'bar'))
assert split_token(()) == (set(), ())
assert split_token('') == (set(), ('', ))
def test_split_token_duplicate():
with pytest.raises(ValueError):
split_token((NOT_MODIFIED, NOT_MODIFIED))
with pytest.raises(ValueError):
split_token((INHERIT, INHERIT))
with pytest.raises(ValueError):
split_token((INHERIT, NOT_MODIFIED, INHERIT))
def test_split_token_not_modified():
with pytest.raises(ValueError):
split_token((NOT_MODIFIED, 'foo', 'bar'))
with pytest.raises(ValueError):
split_token((NOT_MODIFIED, INHERIT))
with pytest.raises(ValueError):
split_token((INHERIT, NOT_MODIFIED))
assert split_token(NOT_MODIFIED) == ({NOT_MODIFIED}, ())
assert split_token((NOT_MODIFIED, )) == ({NOT_MODIFIED}, ())
def test_split_token_inherit():
assert split_token(INHERIT) == ({INHERIT}, ())
assert split_token((INHERIT, )) == ({INHERIT}, ())
assert split_token((INHERIT, 'foo', 'bar')) == ({INHERIT}, ('foo', 'bar'))

View File

@ -0,0 +1,27 @@
from bonobo.constants import INHERIT
from bonobo.util.testing import BufferingNodeExecutionContext
messages = [
('Hello', ),
('Goodbye', ),
]
def append(*args):
return INHERIT, '!'
def test_inherit():
with BufferingNodeExecutionContext(append) as context:
context.write_sync(*messages)
assert context.get_buffer() == list(map(lambda x: x + ('!', ), messages))
def test_inherit_bag_tuple():
with BufferingNodeExecutionContext(append) as context:
context.set_input_fields(['message'])
context.write_sync(*messages)
assert context.get_output_fields() == ('message', '0')
assert context.get_buffer() == list(map(lambda x: x + ('!', ), messages))