[nodes/io] Adds an output_format option to CsvReader (BC ok) for more flexibility.
This commit is contained in:
@ -3,6 +3,8 @@ import csv
|
||||
from bonobo.config import Option
|
||||
from bonobo.config.processors import ContextProcessor
|
||||
from bonobo.constants import NOT_MODIFIED
|
||||
from bonobo.errors import ConfigurationError, ValidationError
|
||||
from bonobo.structs import Bag
|
||||
from bonobo.util.objects import ValueHolder
|
||||
from .file import FileHandler, FileReader, FileWriter
|
||||
|
||||
@ -28,6 +30,14 @@ class CsvHandler(FileHandler):
|
||||
headers = Option(tuple)
|
||||
|
||||
|
||||
def validate_csv_output_format(v):
|
||||
if callable(v):
|
||||
return v
|
||||
if v in {'dict', 'kwargs'}:
|
||||
return v
|
||||
raise ValidationError('Unsupported format {!r}.'.format(v))
|
||||
|
||||
|
||||
class CsvReader(CsvHandler, FileReader):
|
||||
"""
|
||||
Reads a CSV and yield the values as dicts.
|
||||
@ -39,13 +49,23 @@ class CsvReader(CsvHandler, FileReader):
|
||||
"""
|
||||
|
||||
skip = Option(int, default=0)
|
||||
output_format = Option(validate_csv_output_format, default='dict')
|
||||
|
||||
@ContextProcessor
|
||||
def csv_headers(self, context, fs, file):
|
||||
yield ValueHolder(self.headers)
|
||||
|
||||
def get_output_formater(self):
|
||||
if callable(self.output_format):
|
||||
return self.output_format
|
||||
elif isinstance(self.output_format, str):
|
||||
return getattr(self, '_format_as_' + self.output_format)
|
||||
else:
|
||||
raise ConfigurationError('Unsupported format {!r} for {}.'.format(self.output_format, type(self).__name__))
|
||||
|
||||
def read(self, fs, file, headers):
|
||||
reader = csv.reader(file, delimiter=self.delimiter, quotechar=self.quotechar)
|
||||
formater = self.get_output_formater()
|
||||
|
||||
if not headers.get():
|
||||
headers.set(next(reader))
|
||||
@ -60,7 +80,13 @@ class CsvReader(CsvHandler, FileReader):
|
||||
if len(row) != field_count:
|
||||
raise ValueError('Got a line with %d fields, expecting %d.' % (len(row), field_count, ))
|
||||
|
||||
yield dict(zip(headers.value, row))
|
||||
yield formater(headers.get(), row)
|
||||
|
||||
def _format_as_dict(self, headers, values):
|
||||
return dict(zip(headers, values))
|
||||
|
||||
def _format_as_kwargs(self, headers, values):
|
||||
return Bag(**dict(zip(headers, values)))
|
||||
|
||||
|
||||
class CsvWriter(CsvHandler, FileWriter):
|
||||
|
||||
@ -75,6 +75,14 @@ class Bag:
|
||||
|
||||
raise TypeError('Could not apply bag to {}.'.format(func_or_iter))
|
||||
|
||||
def get(self):
|
||||
"""
|
||||
Get a 2 element tuple of this bag's args and kwargs.
|
||||
|
||||
:return: tuple
|
||||
"""
|
||||
return self.args, self.kwargs
|
||||
|
||||
def extend(self, *args, **kwargs):
|
||||
return type(self)(*args, _parent=self, **kwargs)
|
||||
|
||||
|
||||
@ -55,3 +55,38 @@ def test_read_csv_from_file(tmpdir):
|
||||
'b': 'b bar',
|
||||
'c': 'c bar',
|
||||
}
|
||||
|
||||
|
||||
def test_read_csv_kwargs_output_formater(tmpdir):
|
||||
fs, filename = open_fs(tmpdir), 'input.csv'
|
||||
fs.open(filename, 'w').write('a,b,c\na foo,b foo,c foo\na bar,b bar,c bar')
|
||||
|
||||
reader = CsvReader(path=filename, delimiter=',', output_format='kwargs')
|
||||
|
||||
context = CapturingNodeExecutionContext(reader, services={'fs': fs})
|
||||
|
||||
context.start()
|
||||
context.write(BEGIN, Bag(), END)
|
||||
context.step()
|
||||
context.stop()
|
||||
|
||||
assert len(context.send.mock_calls) == 2
|
||||
|
||||
args0, kwargs0 = context.send.call_args_list[0]
|
||||
assert len(args0) == 1 and not len(kwargs0)
|
||||
args1, kwargs1 = context.send.call_args_list[1]
|
||||
assert len(args1) == 1 and not len(kwargs1)
|
||||
|
||||
_args, _kwargs = args0[0].get()
|
||||
assert not len(_args) and _kwargs == {
|
||||
'a': 'a foo',
|
||||
'b': 'b foo',
|
||||
'c': 'c foo',
|
||||
}
|
||||
|
||||
_args, _kwargs = args1[0].get()
|
||||
assert not len(_args) and _kwargs == {
|
||||
'a': 'a bar',
|
||||
'b': 'b bar',
|
||||
'c': 'c bar',
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user