[core] I/O formats allowing both arg0 formating and kwargs based. Starting with 0.4, kwargs based will be default (BC break here, but needed for the greater good).

This commit is contained in:
Romain Dorgueil
2017-06-05 11:38:11 +02:00
parent c34b86872f
commit e5483de344
8 changed files with 111 additions and 82 deletions

View File

@ -2,8 +2,8 @@ import bonobo
from bonobo.commands.run import get_default_services from bonobo.commands.run import get_default_services
graph = bonobo.Graph( graph = bonobo.Graph(
bonobo.CsvReader('datasets/coffeeshops.txt'), bonobo.CsvReader('datasets/coffeeshops.txt', headers=('item',)),
print, bonobo.PrettyPrinter(),
) )
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,15 +1,16 @@
import bonobo import bonobo
from bonobo import Bag
from bonobo.commands.run import get_default_services from bonobo.commands.run import get_default_services
def get_fields(row): def get_fields(**row):
return row['fields'] return Bag(**row['fields'])
graph = bonobo.Graph( graph = bonobo.Graph(
bonobo.JsonReader('datasets/theaters.json'), bonobo.JsonReader('datasets/theaters.json'),
get_fields, get_fields,
bonobo.PrettyPrint(title_keys=('eq_nom_equipement', )), bonobo.PrettyPrinter(),
) )
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,28 +1,3 @@
import bonobo
from fs.tarfs import TarFS
import os
def cleanse_sms(row):
if row['category'] == 'spam':
row['sms_clean'] = '**MARKED AS SPAM** ' + row['sms'][0:50] + (
'...' if len(row['sms']) > 50 else ''
)
else:
row['sms_clean'] = row['sms']
return row['sms_clean']
graph = bonobo.Graph(
bonobo.PickleReader('spam.pkl'
), # spam.pkl is within the gzipped tarball
cleanse_sms,
print
)
if __name__ == '__main__':
''' '''
This example shows how a different file system service can be injected This example shows how a different file system service can be injected
into a transformation (as compressing pickled objects often makes sense into a transformation (as compressing pickled objects often makes sense
@ -49,13 +24,41 @@ if __name__ == '__main__':
The transformation (1) reads the pickled data, (2) marks and shortens The transformation (1) reads the pickled data, (2) marks and shortens
messages categorized as spam, and (3) prints the output. messages categorized as spam, and (3) prints the output.
''' '''
services = { import bonobo
from bonobo.commands.run import get_default_services
from fs.tarfs import TarFS
def cleanse_sms(**row):
if row['category'] == 'spam':
row['sms_clean'] = '**MARKED AS SPAM** ' + row['sms'][0:50] + (
'...' if len(row['sms']) > 50 else ''
)
else:
row['sms_clean'] = row['sms']
return row['sms_clean']
graph = bonobo.Graph(
# spam.pkl is within the gzipped tarball
bonobo.PickleReader('spam.pkl'),
cleanse_sms,
bonobo.PrettyPrinter(),
)
def get_services():
return {
'fs': 'fs':
TarFS( TarFS(
os.path. bonobo.get_examples_path('datasets/spam.tgz')
join(bonobo.get_examples_path(), 'datasets', 'spam.tgz')
) )
} }
bonobo.run(graph, services=services)
if __name__ == '__main__':
bonobo.run(graph, services=get_default_services(__file__))

View File

@ -3,10 +3,8 @@ import csv
from bonobo.config import Option from bonobo.config import Option
from bonobo.config.processors import ContextProcessor from bonobo.config.processors import ContextProcessor
from bonobo.constants import NOT_MODIFIED from bonobo.constants import NOT_MODIFIED
from bonobo.errors import ConfigurationError, ValidationError from bonobo.nodes.io.file import FileHandler, FileReader, FileWriter
from bonobo.structs import Bag
from bonobo.util.objects import ValueHolder from bonobo.util.objects import ValueHolder
from .file import FileHandler, FileReader, FileWriter
class CsvHandler(FileHandler): class CsvHandler(FileHandler):
@ -30,14 +28,6 @@ class CsvHandler(FileHandler):
headers = Option(tuple) headers = Option(tuple)
def validate_csv_output_format(v):
if callable(v):
return v
if v in {'dict', 'kwargs'}:
return v
raise ValidationError('Unsupported format {!r}.'.format(v))
class CsvReader(CsvHandler, FileReader): class CsvReader(CsvHandler, FileReader):
""" """
Reads a CSV and yield the values as dicts. Reads a CSV and yield the values as dicts.
@ -49,26 +39,17 @@ class CsvReader(CsvHandler, FileReader):
""" """
skip = Option(int, default=0) skip = Option(int, default=0)
output_format = Option(validate_csv_output_format, default='dict')
@ContextProcessor @ContextProcessor
def csv_headers(self, context, fs, file): def csv_headers(self, context, fs, file):
yield ValueHolder(self.headers) yield ValueHolder(self.headers)
def get_output_formater(self):
if callable(self.output_format):
return self.output_format
elif isinstance(self.output_format, str):
return getattr(self, '_format_as_' + self.output_format)
else:
raise ConfigurationError('Unsupported format {!r} for {}.'.format(self.output_format, type(self).__name__))
def read(self, fs, file, headers): def read(self, fs, file, headers):
reader = csv.reader(file, delimiter=self.delimiter, quotechar=self.quotechar) reader = csv.reader(file, delimiter=self.delimiter, quotechar=self.quotechar)
formater = self.get_output_formater()
if not headers.get(): if not headers.get():
headers.set(next(reader)) headers.set(next(reader))
_headers = headers.get()
field_count = len(headers) field_count = len(headers)
@ -80,13 +61,7 @@ class CsvReader(CsvHandler, FileReader):
if len(row) != field_count: if len(row) != field_count:
raise ValueError('Got a line with %d fields, expecting %d.' % (len(row), field_count,)) raise ValueError('Got a line with %d fields, expecting %d.' % (len(row), field_count,))
yield formater(headers.get(), row) yield self.get_output(dict(zip(_headers, row)))
def _format_as_dict(self, headers, values):
return dict(zip(headers, values))
def _format_as_kwargs(self, headers, values):
return Bag(**dict(zip(headers, values)))
class CsvWriter(CsvHandler, FileWriter): class CsvWriter(CsvHandler, FileWriter):
@ -96,7 +71,8 @@ class CsvWriter(CsvHandler, FileWriter):
headers = ValueHolder(list(self.headers) if self.headers else None) headers = ValueHolder(list(self.headers) if self.headers else None)
yield writer, headers yield writer, headers
def write(self, fs, file, lineno, writer, headers, row): def write(self, fs, file, lineno, writer, headers, *args, **kwargs):
row = self.get_input(*args, **kwargs)
if not lineno: if not lineno:
headers.set(headers.value or row.keys()) headers.set(headers.value or row.keys())
writer.writerow(headers.get()) writer.writerow(headers.get())

View File

@ -1,7 +1,9 @@
from bonobo import settings
from bonobo.config import Option, Service from bonobo.config import Option, Service
from bonobo.config.configurables import Configurable from bonobo.config.configurables import Configurable
from bonobo.config.processors import ContextProcessor from bonobo.config.processors import ContextProcessor
from bonobo.constants import NOT_MODIFIED from bonobo.constants import NOT_MODIFIED
from bonobo.structs.bags import Bag
from bonobo.util.objects import ValueHolder from bonobo.util.objects import ValueHolder
@ -22,6 +24,8 @@ class FileHandler(Configurable):
fs = Service('fs') # type: str fs = Service('fs') # type: str
ioformat = Option(settings.validate_io_format, default=settings.IOFORMAT)
@ContextProcessor @ContextProcessor
def file(self, context, fs): def file(self, context, fs):
with self.open(fs) as file: with self.open(fs) as file:
@ -30,15 +34,35 @@ class FileHandler(Configurable):
def open(self, fs): def open(self, fs):
return fs.open(self.path, self.mode, encoding=self.encoding) return fs.open(self.path, self.mode, encoding=self.encoding)
def get_input(self, *args, **kwargs):
if self.ioformat == settings.IOFORMAT_ARG0:
assert len(args) == 1 and not len(kwargs), 'ARG0 format implies one arg and no kwargs.'
return args[0]
if self.ioformat == settings.IOFORMAT_KWARGS:
assert len(args) == 0 and len(kwargs), 'KWARGS format implies no arg.'
return kwargs
raise NotImplementedError('Unsupported format.')
def get_output(self, row):
if self.ioformat == settings.IOFORMAT_ARG0:
return row
if self.ioformat == settings.IOFORMAT_KWARGS:
return Bag(**row)
raise NotImplementedError('Unsupported format.')
class Reader(FileHandler): class Reader(FileHandler):
"""Abstract component factory for readers. """Abstract component factory for readers.
""" """
def __call__(self, *args): def __call__(self, *args, **kwargs):
yield from self.read(*args) yield from self.read(*args, **kwargs)
def read(self, *args): def read(self, *args, **kwargs):
raise NotImplementedError('Abstract.') raise NotImplementedError('Abstract.')
@ -46,10 +70,10 @@ class Writer(FileHandler):
"""Abstract component factory for writers. """Abstract component factory for writers.
""" """
def __call__(self, *args): def __call__(self, *args, **kwargs):
return self.write(*args) return self.write(*args)
def write(self, *args): def write(self, *args, **kwargs):
raise NotImplementedError('Abstract.') raise NotImplementedError('Abstract.')

View File

@ -14,7 +14,7 @@ class JsonReader(JsonHandler, FileReader):
def read(self, fs, file): def read(self, fs, file):
for line in self.loader(file): for line in self.loader(file):
yield line yield self.get_output(line)
class JsonWriter(JsonHandler, FileWriter): class JsonWriter(JsonHandler, FileWriter):

View File

@ -53,7 +53,7 @@ class PickleReader(PickleHandler, FileReader):
if len(i) != item_count: if len(i) != item_count:
raise ValueError('Received an object with %d items, expecting %d.' % (len(i), item_count, )) raise ValueError('Received an object with %d items, expecting %d.' % (len(i), item_count, ))
yield dict(zip(i)) if is_dict else dict(zip(pickle_headers.value, i)) yield self.get_output(dict(zip(i)) if is_dict else dict(zip(pickle_headers.value, i)))
class PickleWriter(PickleHandler, FileWriter): class PickleWriter(PickleHandler, FileWriter):

View File

@ -2,6 +2,8 @@ import os
import logging import logging
from bonobo.errors import ValidationError
def to_bool(s): def to_bool(s):
if len(s): if len(s):
@ -23,7 +25,30 @@ QUIET = to_bool(os.environ.get('QUIET', 'f'))
# Logging level. # Logging level.
LOGGING_LEVEL = logging.DEBUG if DEBUG else logging.INFO LOGGING_LEVEL = logging.DEBUG if DEBUG else logging.INFO
# Input/Output format for transformations
IOFORMAT_ARG0 = 'arg0'
IOFORMAT_KWARGS = 'kwargs'
IOFORMATS = {
IOFORMAT_ARG0,
IOFORMAT_KWARGS,
}
IOFORMAT = os.environ.get('IOFORMAT', IOFORMAT_KWARGS)
def validate_io_format(v):
if callable(v):
return v
if v in IOFORMATS:
return v
raise ValidationError('Unsupported format {!r}.'.format(v))
def check(): def check():
if DEBUG and QUIET: if DEBUG and QUIET:
raise RuntimeError('I cannot be verbose and quiet at the same time.') raise RuntimeError('I cannot be verbose and quiet at the same time.')
if IOFORMAT not in IOFORMATS:
raise RuntimeError('Invalid default input/output format.')