[nodes/io] Adds an output_format option to CsvReader (BC ok) for more flexibility.

This commit is contained in:
Romain Dorgueil
2017-05-25 15:14:51 +02:00
parent 1c500b31cf
commit 9375a5d1fb
3 changed files with 70 additions and 1 deletions

View File

@ -3,6 +3,8 @@ import csv
from bonobo.config import Option
from bonobo.config.processors import ContextProcessor
from bonobo.constants import NOT_MODIFIED
from bonobo.errors import ConfigurationError, ValidationError
from bonobo.structs import Bag
from bonobo.util.objects import ValueHolder
from .file import FileHandler, FileReader, FileWriter
@ -28,6 +30,14 @@ class CsvHandler(FileHandler):
headers = Option(tuple)
def validate_csv_output_format(v):
if callable(v):
return v
if v in {'dict', 'kwargs'}:
return v
raise ValidationError('Unsupported format {!r}.'.format(v))
class CsvReader(CsvHandler, FileReader):
"""
Reads a CSV and yield the values as dicts.
@ -39,13 +49,23 @@ class CsvReader(CsvHandler, FileReader):
"""
skip = Option(int, default=0)
output_format = Option(validate_csv_output_format, default='dict')
@ContextProcessor
def csv_headers(self, context, fs, file):
yield ValueHolder(self.headers)
def get_output_formater(self):
if callable(self.output_format):
return self.output_format
elif isinstance(self.output_format, str):
return getattr(self, '_format_as_' + self.output_format)
else:
raise ConfigurationError('Unsupported format {!r} for {}.'.format(self.output_format, type(self).__name__))
def read(self, fs, file, headers):
reader = csv.reader(file, delimiter=self.delimiter, quotechar=self.quotechar)
formater = self.get_output_formater()
if not headers.get():
headers.set(next(reader))
@ -60,7 +80,13 @@ class CsvReader(CsvHandler, FileReader):
if len(row) != field_count:
raise ValueError('Got a line with %d fields, expecting %d.' % (len(row), field_count, ))
yield dict(zip(headers.value, row))
yield formater(headers.get(), row)
def _format_as_dict(self, headers, values):
return dict(zip(headers, values))
def _format_as_kwargs(self, headers, values):
return Bag(**dict(zip(headers, values)))
class CsvWriter(CsvHandler, FileWriter):

View File

@ -75,6 +75,14 @@ class Bag:
raise TypeError('Could not apply bag to {}.'.format(func_or_iter))
def get(self):
"""
Get a 2 element tuple of this bag's args and kwargs.
:return: tuple
"""
return self.args, self.kwargs
def extend(self, *args, **kwargs):
return type(self)(*args, _parent=self, **kwargs)

View File

@ -55,3 +55,38 @@ def test_read_csv_from_file(tmpdir):
'b': 'b bar',
'c': 'c bar',
}
def test_read_csv_kwargs_output_formater(tmpdir):
fs, filename = open_fs(tmpdir), 'input.csv'
fs.open(filename, 'w').write('a,b,c\na foo,b foo,c foo\na bar,b bar,c bar')
reader = CsvReader(path=filename, delimiter=',', output_format='kwargs')
context = CapturingNodeExecutionContext(reader, services={'fs': fs})
context.start()
context.write(BEGIN, Bag(), END)
context.step()
context.stop()
assert len(context.send.mock_calls) == 2
args0, kwargs0 = context.send.call_args_list[0]
assert len(args0) == 1 and not len(kwargs0)
args1, kwargs1 = context.send.call_args_list[1]
assert len(args1) == 1 and not len(kwargs1)
_args, _kwargs = args0[0].get()
assert not len(_args) and _kwargs == {
'a': 'a foo',
'b': 'b foo',
'c': 'c foo',
}
_args, _kwargs = args1[0].get()
assert not len(_args) and _kwargs == {
'a': 'a bar',
'b': 'b bar',
'c': 'c bar',
}