[nodes/io] Adds an output_format option to CsvReader (BC ok) for more flexibility.
This commit is contained in:
@ -3,6 +3,8 @@ import csv
|
|||||||
from bonobo.config import Option
|
from bonobo.config import Option
|
||||||
from bonobo.config.processors import ContextProcessor
|
from bonobo.config.processors import ContextProcessor
|
||||||
from bonobo.constants import NOT_MODIFIED
|
from bonobo.constants import NOT_MODIFIED
|
||||||
|
from bonobo.errors import ConfigurationError, ValidationError
|
||||||
|
from bonobo.structs import Bag
|
||||||
from bonobo.util.objects import ValueHolder
|
from bonobo.util.objects import ValueHolder
|
||||||
from .file import FileHandler, FileReader, FileWriter
|
from .file import FileHandler, FileReader, FileWriter
|
||||||
|
|
||||||
@ -28,6 +30,14 @@ class CsvHandler(FileHandler):
|
|||||||
headers = Option(tuple)
|
headers = Option(tuple)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_csv_output_format(v):
|
||||||
|
if callable(v):
|
||||||
|
return v
|
||||||
|
if v in {'dict', 'kwargs'}:
|
||||||
|
return v
|
||||||
|
raise ValidationError('Unsupported format {!r}.'.format(v))
|
||||||
|
|
||||||
|
|
||||||
class CsvReader(CsvHandler, FileReader):
|
class CsvReader(CsvHandler, FileReader):
|
||||||
"""
|
"""
|
||||||
Reads a CSV and yield the values as dicts.
|
Reads a CSV and yield the values as dicts.
|
||||||
@ -39,13 +49,23 @@ class CsvReader(CsvHandler, FileReader):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
skip = Option(int, default=0)
|
skip = Option(int, default=0)
|
||||||
|
output_format = Option(validate_csv_output_format, default='dict')
|
||||||
|
|
||||||
@ContextProcessor
|
@ContextProcessor
|
||||||
def csv_headers(self, context, fs, file):
|
def csv_headers(self, context, fs, file):
|
||||||
yield ValueHolder(self.headers)
|
yield ValueHolder(self.headers)
|
||||||
|
|
||||||
|
def get_output_formater(self):
|
||||||
|
if callable(self.output_format):
|
||||||
|
return self.output_format
|
||||||
|
elif isinstance(self.output_format, str):
|
||||||
|
return getattr(self, '_format_as_' + self.output_format)
|
||||||
|
else:
|
||||||
|
raise ConfigurationError('Unsupported format {!r} for {}.'.format(self.output_format, type(self).__name__))
|
||||||
|
|
||||||
def read(self, fs, file, headers):
|
def read(self, fs, file, headers):
|
||||||
reader = csv.reader(file, delimiter=self.delimiter, quotechar=self.quotechar)
|
reader = csv.reader(file, delimiter=self.delimiter, quotechar=self.quotechar)
|
||||||
|
formater = self.get_output_formater()
|
||||||
|
|
||||||
if not headers.get():
|
if not headers.get():
|
||||||
headers.set(next(reader))
|
headers.set(next(reader))
|
||||||
@ -60,7 +80,13 @@ class CsvReader(CsvHandler, FileReader):
|
|||||||
if len(row) != field_count:
|
if len(row) != field_count:
|
||||||
raise ValueError('Got a line with %d fields, expecting %d.' % (len(row), field_count, ))
|
raise ValueError('Got a line with %d fields, expecting %d.' % (len(row), field_count, ))
|
||||||
|
|
||||||
yield dict(zip(headers.value, row))
|
yield formater(headers.get(), row)
|
||||||
|
|
||||||
|
def _format_as_dict(self, headers, values):
|
||||||
|
return dict(zip(headers, values))
|
||||||
|
|
||||||
|
def _format_as_kwargs(self, headers, values):
|
||||||
|
return Bag(**dict(zip(headers, values)))
|
||||||
|
|
||||||
|
|
||||||
class CsvWriter(CsvHandler, FileWriter):
|
class CsvWriter(CsvHandler, FileWriter):
|
||||||
|
|||||||
@ -75,6 +75,14 @@ class Bag:
|
|||||||
|
|
||||||
raise TypeError('Could not apply bag to {}.'.format(func_or_iter))
|
raise TypeError('Could not apply bag to {}.'.format(func_or_iter))
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
"""
|
||||||
|
Get a 2 element tuple of this bag's args and kwargs.
|
||||||
|
|
||||||
|
:return: tuple
|
||||||
|
"""
|
||||||
|
return self.args, self.kwargs
|
||||||
|
|
||||||
def extend(self, *args, **kwargs):
|
def extend(self, *args, **kwargs):
|
||||||
return type(self)(*args, _parent=self, **kwargs)
|
return type(self)(*args, _parent=self, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@ -55,3 +55,38 @@ def test_read_csv_from_file(tmpdir):
|
|||||||
'b': 'b bar',
|
'b': 'b bar',
|
||||||
'c': 'c bar',
|
'c': 'c bar',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_csv_kwargs_output_formater(tmpdir):
|
||||||
|
fs, filename = open_fs(tmpdir), 'input.csv'
|
||||||
|
fs.open(filename, 'w').write('a,b,c\na foo,b foo,c foo\na bar,b bar,c bar')
|
||||||
|
|
||||||
|
reader = CsvReader(path=filename, delimiter=',', output_format='kwargs')
|
||||||
|
|
||||||
|
context = CapturingNodeExecutionContext(reader, services={'fs': fs})
|
||||||
|
|
||||||
|
context.start()
|
||||||
|
context.write(BEGIN, Bag(), END)
|
||||||
|
context.step()
|
||||||
|
context.stop()
|
||||||
|
|
||||||
|
assert len(context.send.mock_calls) == 2
|
||||||
|
|
||||||
|
args0, kwargs0 = context.send.call_args_list[0]
|
||||||
|
assert len(args0) == 1 and not len(kwargs0)
|
||||||
|
args1, kwargs1 = context.send.call_args_list[1]
|
||||||
|
assert len(args1) == 1 and not len(kwargs1)
|
||||||
|
|
||||||
|
_args, _kwargs = args0[0].get()
|
||||||
|
assert not len(_args) and _kwargs == {
|
||||||
|
'a': 'a foo',
|
||||||
|
'b': 'b foo',
|
||||||
|
'c': 'c foo',
|
||||||
|
}
|
||||||
|
|
||||||
|
_args, _kwargs = args1[0].get()
|
||||||
|
assert not len(_args) and _kwargs == {
|
||||||
|
'a': 'a bar',
|
||||||
|
'b': 'b bar',
|
||||||
|
'c': 'c bar',
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user