diff --git a/bonobo/nodes/io/csv.py b/bonobo/nodes/io/csv.py index 45b40de..e0412fa 100644 --- a/bonobo/nodes/io/csv.py +++ b/bonobo/nodes/io/csv.py @@ -3,6 +3,8 @@ import csv from bonobo.config import Option from bonobo.config.processors import ContextProcessor from bonobo.constants import NOT_MODIFIED +from bonobo.errors import ConfigurationError, ValidationError +from bonobo.structs import Bag from bonobo.util.objects import ValueHolder from .file import FileHandler, FileReader, FileWriter @@ -28,6 +30,14 @@ class CsvHandler(FileHandler): headers = Option(tuple) +def validate_csv_output_format(v): + if callable(v): + return v + if v in {'dict', 'kwargs'}: + return v + raise ValidationError('Unsupported format {!r}.'.format(v)) + + class CsvReader(CsvHandler, FileReader): """ Reads a CSV and yield the values as dicts. @@ -39,13 +49,23 @@ class CsvReader(CsvHandler, FileReader): """ skip = Option(int, default=0) + output_format = Option(validate_csv_output_format, default='dict') @ContextProcessor def csv_headers(self, context, fs, file): yield ValueHolder(self.headers) + def get_output_formater(self): + if callable(self.output_format): + return self.output_format + elif isinstance(self.output_format, str): + return getattr(self, '_format_as_' + self.output_format) + else: + raise ConfigurationError('Unsupported format {!r} for {}.'.format(self.output_format, type(self).__name__)) + def read(self, fs, file, headers): reader = csv.reader(file, delimiter=self.delimiter, quotechar=self.quotechar) + formater = self.get_output_formater() if not headers.get(): headers.set(next(reader)) @@ -60,7 +80,13 @@ class CsvReader(CsvHandler, FileReader): if len(row) != field_count: raise ValueError('Got a line with %d fields, expecting %d.' % (len(row), field_count, )) - yield dict(zip(headers.value, row)) + yield formater(headers.get(), row) + + def _format_as_dict(self, headers, values): + return dict(zip(headers, values)) + + def _format_as_kwargs(self, headers, values): + return Bag(**dict(zip(headers, values))) class CsvWriter(CsvHandler, FileWriter): diff --git a/bonobo/structs/bags.py b/bonobo/structs/bags.py index 5fec1f2..4ef2fa7 100644 --- a/bonobo/structs/bags.py +++ b/bonobo/structs/bags.py @@ -75,6 +75,14 @@ class Bag: raise TypeError('Could not apply bag to {}.'.format(func_or_iter)) + def get(self): + """ + Get a 2 element tuple of this bag's args and kwargs. + + :return: tuple + """ + return self.args, self.kwargs + def extend(self, *args, **kwargs): return type(self)(*args, _parent=self, **kwargs) diff --git a/tests/io/test_csv.py b/tests/io/test_csv.py index 59f7197..bded111 100644 --- a/tests/io/test_csv.py +++ b/tests/io/test_csv.py @@ -55,3 +55,38 @@ def test_read_csv_from_file(tmpdir): 'b': 'b bar', 'c': 'c bar', } + + +def test_read_csv_kwargs_output_formater(tmpdir): + fs, filename = open_fs(tmpdir), 'input.csv' + fs.open(filename, 'w').write('a,b,c\na foo,b foo,c foo\na bar,b bar,c bar') + + reader = CsvReader(path=filename, delimiter=',', output_format='kwargs') + + context = CapturingNodeExecutionContext(reader, services={'fs': fs}) + + context.start() + context.write(BEGIN, Bag(), END) + context.step() + context.stop() + + assert len(context.send.mock_calls) == 2 + + args0, kwargs0 = context.send.call_args_list[0] + assert len(args0) == 1 and not len(kwargs0) + args1, kwargs1 = context.send.call_args_list[1] + assert len(args1) == 1 and not len(kwargs1) + + _args, _kwargs = args0[0].get() + assert not len(_args) and _kwargs == { + 'a': 'a foo', + 'b': 'b foo', + 'c': 'c foo', + } + + _args, _kwargs = args1[0].get() + assert not len(_args) and _kwargs == { + 'a': 'a bar', + 'b': 'b bar', + 'c': 'c bar', + }