Refactoring and fixes around ioformats.

This commit is contained in:
Romain Dorgueil
2017-06-08 21:47:01 +02:00
parent d19178a28e
commit 1ca48d885d
18 changed files with 69 additions and 49 deletions

View File

@ -6,6 +6,7 @@ DEFAULT_SERVICES_ATTR = 'get_services'
DEFAULT_GRAPH_FILENAME = '__main__.py'
DEFAULT_GRAPH_ATTR = 'get_graph'
def get_default_services(filename, services=None):
dirname = os.path.dirname(filename)
services_filename = os.path.join(dirname, DEFAULT_SERVICES_FILENAME)

View File

@ -2,7 +2,7 @@ import bonobo
from bonobo.commands.run import get_default_services
graph = bonobo.Graph(
bonobo.CsvReader('datasets/coffeeshops.txt', headers=('item',)),
bonobo.CsvReader('datasets/coffeeshops.txt', headers=('item', )),
bonobo.PrettyPrinter(),
)

View File

@ -52,12 +52,7 @@ graph = bonobo.Graph(
def get_services():
return {
'fs':
TarFS(
bonobo.get_examples_path('datasets/spam.tgz')
)
}
return {'fs': TarFS(bonobo.get_examples_path('datasets/spam.tgz'))}
if __name__ == '__main__':

View File

@ -1,7 +1,11 @@
import bonobo
graph = bonobo.Graph(
['foo', 'bar', 'baz', ],
[
'foo',
'bar',
'baz',
],
str.upper,
print,
)

View File

@ -59,7 +59,7 @@ class CsvReader(CsvHandler, FileReader):
for row in reader:
if len(row) != field_count:
raise ValueError('Got a line with %d fields, expecting %d.' % (len(row), field_count,))
raise ValueError('Got a line with %d fields, expecting %d.' % (len(row), field_count, ))
yield self.get_output(dict(zip(_headers, row)))

View File

@ -21,10 +21,8 @@ class FileHandler(Configurable):
eol = Option(str, default='\n') # type: str
mode = Option(str) # type: str
encoding = Option(str, default='utf-8') # type: str
fs = Service('fs') # type: str
ioformat = Option(settings.validate_io_format, default=settings.IOFORMAT)
ioformat = Option(default=settings.IOFORMAT.get)
@ContextProcessor
def file(self, context, fs):

View File

@ -1,6 +1,5 @@
import os
import logging
import os
from bonobo.errors import ValidationError
@ -13,6 +12,36 @@ def to_bool(s):
return False
class Setting:
def __init__(self, name, default=None, validator=None):
self.name = name
if default:
self.default = default if callable(default) else lambda: default
else:
self.default = lambda: None
if validator:
self.validator = validator
else:
self.validator = None
def __repr__(self):
return '<Setting {}={!r}>'.format(self.name, self.value)
def set(self, value):
if self.validator and not self.validator(value):
raise ValidationError('Invalid value {!r} for setting {}.'.format(value, self.name))
self.value = value
def get(self):
try:
return self.value
except AttributeError:
self.value = self.default()
return self.value
# Debug/verbose mode.
DEBUG = to_bool(os.environ.get('DEBUG', 'f'))
@ -34,21 +63,9 @@ IOFORMATS = {
IOFORMAT_KWARGS,
}
IOFORMAT = os.environ.get('IOFORMAT', IOFORMAT_KWARGS)
def validate_io_format(v):
if callable(v):
return v
if v in IOFORMATS:
return v
raise ValidationError('Unsupported format {!r}.'.format(v))
IOFORMAT = Setting('IOFORMAT', default=IOFORMAT_KWARGS, validator=IOFORMATS.__contains__)
def check():
if DEBUG and QUIET:
raise RuntimeError('I cannot be verbose and quiet at the same time.')
if IOFORMAT not in IOFORMATS:
raise RuntimeError('Invalid default input/output format.')

View File

@ -21,7 +21,7 @@ def force_iterator(mixed):
def ensure_tuple(tuple_or_mixed):
if isinstance(tuple_or_mixed, tuple):
return tuple_or_mixed
return (tuple_or_mixed,)
return (tuple_or_mixed, )
def tuplize(generator):