[stdlib] Fix I/O related nodes (especially json), there were bad bugs with ioformat.

This commit is contained in:
Romain Dorgueil
2017-06-17 10:37:17 +02:00
parent 3c4010f9c3
commit 67b4227436
13 changed files with 243 additions and 219 deletions

9
.gitignore vendored
View File

@ -2,6 +2,7 @@
*,cover *,cover
*.egg *.egg
*.egg-info/ *.egg-info/
*.iml
*.log *.log
*.manifest *.manifest
*.mo *.mo
@ -20,25 +21,17 @@
.installed.cfg .installed.cfg
.ipynb_checkpoints .ipynb_checkpoints
.python-version .python-version
.tox/
.webassets-cache
/.idea /.idea
/.release /.release
/bonobo.iml
/bonobo/examples/work_in_progress/ /bonobo/examples/work_in_progress/
/bonobo/ext/jupyter/js/node_modules/ /bonobo/ext/jupyter/js/node_modules/
/build/ /build/
/coverage.xml /coverage.xml
/develop-eggs/
/dist/ /dist/
/docs/_build/ /docs/_build/
/downloads/
/eggs/ /eggs/
/examples/private /examples/private
/htmlcov/
/sdist/ /sdist/
/tags /tags
celerybeat-schedule
parts/
pip-delete-this-directory.txt pip-delete-this-directory.txt
pip-log.txt pip-log.txt

View File

@ -70,9 +70,9 @@ def set_level(level):
def get_logger(name='bonobo'): def get_logger(name='bonobo'):
return logging.getLogger(name) return logging.getLogger(name)
# Compatibility with python logging # Compatibility with python logging
getLogger = get_logger getLogger = get_logger
# Setup formating and level. # Setup formating and level.
setup(level=settings.LOGGING_LEVEL) setup(level=settings.LOGGING_LEVEL)

82
bonobo/nodes/io/base.py Normal file
View File

@ -0,0 +1,82 @@
from bonobo import settings
from bonobo.config import Configurable, ContextProcessor, Option, Service
from bonobo.structs.bags import Bag
class IOFormatEnabled(Configurable):
    """Mixin converting between node call arguments and rows, driven by ``ioformat``.

    Two formats are handled, compared against constants from ``settings``:

    * ``settings.IOFORMAT_ARG0``: the row is the single positional argument.
    * ``settings.IOFORMAT_KWARGS``: the row is the keyword arguments mapping.

    Any other value makes both helpers raise ``NotImplementedError``.
    """

    # Default is taken from ``settings.IOFORMAT.get``.
    # NOTE(review): presumably a callable resolved by ``Option`` on access -- confirm.
    ioformat = Option(default=settings.IOFORMAT.get)

    def get_input(self, *args, **kwargs):
        """Extract the input row from the call arguments.

        Returns:
            ``args[0]`` in ARG0 mode, the ``kwargs`` dict in KWARGS mode.

        Raises:
            ValueError: if the arguments do not match the configured format
                (ARG0: exactly one positional argument and no keyword
                argument; KWARGS: no positional argument and at least one
                keyword argument).
            NotImplementedError: if ``ioformat`` is neither known format.
        """
        if self.ioformat == settings.IOFORMAT_ARG0:
            if len(args) != 1 or kwargs:
                raise ValueError(
                    'Wrong input formatting: IOFORMAT=ARG0 implies one arg and no kwargs, '
                    'got args={!r} and kwargs={!r}.'.format(args, kwargs)
                )
            return args[0]

        if self.ioformat == settings.IOFORMAT_KWARGS:
            # Empty kwargs are rejected as well, so the message states both conditions.
            if args or not kwargs:
                raise ValueError(
                    'Wrong input formatting: IOFORMAT=KWARGS implies no arg and at least one kwarg, '
                    'got args={!r} and kwargs={!r}.'.format(args, kwargs)
                )
            return kwargs

        raise NotImplementedError('Unsupported format.')

    def get_output(self, row):
        """Wrap an outgoing row according to the configured format.

        Returns:
            ``row`` unchanged in ARG0 mode, ``Bag(**row)`` in KWARGS mode
            (so ``row`` must be a mapping with string keys in that mode).

        Raises:
            NotImplementedError: if ``ioformat`` is neither known format.
        """
        if self.ioformat == settings.IOFORMAT_ARG0:
            return row

        if self.ioformat == settings.IOFORMAT_KWARGS:
            return Bag(**row)

        raise NotImplementedError('Unsupported format.')
class FileHandler(Configurable):
    """Base configurable for components working on one file of a filesystem service.

    Options:
        path (str): path of the file within the provided filesystem (required, positional).
        eol (str): character used to separate lines, a newline by default.
        mode (str): mode handed to the filesystem when opening the file.
        encoding (str): encoding handed to the filesystem when opening the file, utf-8 by default.
        fs (str): name of the service used as filesystem.
    """

    # Declaration order is kept as-is; NOTE(review): it may matter for positional options.
    path = Option(str, required=True, positional=True)  # type: str
    eol = Option(str, default='\n')  # type: str
    mode = Option(str)  # type: str
    encoding = Option(str, default='utf-8')  # type: str

    fs = Service('fs')  # type: str

    @ContextProcessor
    def file(self, context, fs):
        """Yield the opened file; leaving the ``with`` block closes it."""
        handle = self.open(fs)
        with handle as opened:
            yield opened

    def open(self, fs):
        """Open ``self.path`` on *fs* using the configured mode and encoding."""
        return fs.open(self.path, self.mode, encoding=self.encoding)
class Reader:
    """Abstract base for reader components.

    Calling an instance delegates to :meth:`read`, which subclasses must
    implement as an iterable-returning method.
    """

    def __call__(self, *args, **kwargs):
        """Re-yield everything produced by ``read`` for the same arguments."""
        produced = self.read(*args, **kwargs)
        yield from produced

    def read(self, *args, **kwargs):
        """Produce the rows; must be overridden by concrete readers."""
        raise NotImplementedError('Abstract.')
class Writer:
    """Abstract base for writer components.

    Calling an instance delegates to :meth:`write`, which subclasses must
    implement.
    """

    def __call__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``write``."""
        outcome = self.write(*args, **kwargs)
        return outcome

    def write(self, *args, **kwargs):
        """Consume one input; must be overridden by concrete writers."""
        raise NotImplementedError('Abstract.')

View File

@ -3,7 +3,8 @@ import csv
from bonobo.config import Option from bonobo.config import Option
from bonobo.config.processors import ContextProcessor from bonobo.config.processors import ContextProcessor
from bonobo.constants import NOT_MODIFIED from bonobo.constants import NOT_MODIFIED
from bonobo.nodes.io.file import FileHandler, FileReader, FileWriter from bonobo.nodes.io.file import FileReader, FileWriter
from bonobo.nodes.io.base import FileHandler, IOFormatEnabled
from bonobo.util.objects import ValueHolder from bonobo.util.objects import ValueHolder
@ -28,7 +29,7 @@ class CsvHandler(FileHandler):
headers = Option(tuple) headers = Option(tuple)
class CsvReader(CsvHandler, FileReader): class CsvReader(IOFormatEnabled, FileReader, CsvHandler):
""" """
Reads a CSV and yield the values as dicts. Reads a CSV and yield the values as dicts.
@ -64,7 +65,7 @@ class CsvReader(CsvHandler, FileReader):
yield self.get_output(dict(zip(_headers, row))) yield self.get_output(dict(zip(_headers, row)))
class CsvWriter(CsvHandler, FileWriter): class CsvWriter(IOFormatEnabled, FileWriter, CsvHandler):
@ContextProcessor @ContextProcessor
def writer(self, context, fs, file, lineno): def writer(self, context, fs, file, lineno):
writer = csv.writer(file, delimiter=self.delimiter, quotechar=self.quotechar, lineterminator=self.eol) writer = csv.writer(file, delimiter=self.delimiter, quotechar=self.quotechar, lineterminator=self.eol)

View File

@ -1,81 +1,11 @@
from bonobo import settings from bonobo.config import Option
from bonobo.config import Option, Service
from bonobo.config.configurables import Configurable
from bonobo.config.processors import ContextProcessor from bonobo.config.processors import ContextProcessor
from bonobo.constants import NOT_MODIFIED from bonobo.constants import NOT_MODIFIED
from bonobo.structs.bags import Bag from bonobo.nodes.io.base import FileHandler, Reader, Writer
from bonobo.util.objects import ValueHolder from bonobo.util.objects import ValueHolder
class FileHandler(Configurable): class FileReader(Reader, FileHandler):
"""Abstract component factory for file-related components.
Args:
path (str): which path to use within the provided filesystem.
eol (str): which character to use to separate lines.
mode (str): which mode to use when opening the file.
fs (str): service name to use for filesystem.
"""
path = Option(str, required=True, positional=True) # type: str
eol = Option(str, default='\n') # type: str
mode = Option(str) # type: str
encoding = Option(str, default='utf-8') # type: str
fs = Service('fs') # type: str
ioformat = Option(default=settings.IOFORMAT.get)
@ContextProcessor
def file(self, context, fs):
with self.open(fs) as file:
yield file
def open(self, fs):
return fs.open(self.path, self.mode, encoding=self.encoding)
def get_input(self, *args, **kwargs):
if self.ioformat == settings.IOFORMAT_ARG0:
assert len(args) == 1 and not len(kwargs), 'ARG0 format implies one arg and no kwargs.'
return args[0]
if self.ioformat == settings.IOFORMAT_KWARGS:
assert len(args) == 0 and len(kwargs), 'KWARGS format implies no arg.'
return kwargs
raise NotImplementedError('Unsupported format.')
def get_output(self, row):
if self.ioformat == settings.IOFORMAT_ARG0:
return row
if self.ioformat == settings.IOFORMAT_KWARGS:
return Bag(**row)
raise NotImplementedError('Unsupported format.')
class Reader(FileHandler):
"""Abstract component factory for readers.
"""
def __call__(self, *args, **kwargs):
yield from self.read(*args, **kwargs)
def read(self, *args, **kwargs):
raise NotImplementedError('Abstract.')
class Writer(FileHandler):
"""Abstract component factory for writers.
"""
def __call__(self, *args, **kwargs):
return self.write(*args)
def write(self, *args, **kwargs):
raise NotImplementedError('Abstract.')
class FileReader(Reader):
"""Component factory for file-like readers. """Component factory for file-like readers.
On its own, it can be used to read a file and yield one row per line, trimming the "eol" character at the end if On its own, it can be used to read a file and yield one row per line, trimming the "eol" character at the end if
@ -93,7 +23,7 @@ class FileReader(Reader):
yield line.rstrip(self.eol) yield line.rstrip(self.eol)
class FileWriter(Writer): class FileWriter(Writer, FileHandler):
"""Component factory for file or file-like writers. """Component factory for file or file-like writers.
On its own, it can be used to write in a file one line per row that comes into this component. Extending it is On its own, it can be used to write in a file one line per row that comes into this component. Extending it is
@ -107,11 +37,11 @@ class FileWriter(Writer):
lineno = ValueHolder(0) lineno = ValueHolder(0)
yield lineno yield lineno
def write(self, fs, file, lineno, row): def write(self, fs, file, lineno, line):
""" """
Write a row on the next line of opened file in context. Write a row on the next line of opened file in context.
""" """
self._write_line(file, (self.eol if lineno.value else '') + row) self._write_line(file, (self.eol if lineno.value else '') + line)
lineno += 1 lineno += 1
return NOT_MODIFIED return NOT_MODIFIED

View File

@ -1,15 +1,17 @@
import json import json
from bonobo.config.processors import ContextProcessor from bonobo.config.processors import ContextProcessor
from bonobo.nodes.io.file import FileWriter, FileReader from bonobo.constants import NOT_MODIFIED
from bonobo.nodes.io.base import FileHandler, IOFormatEnabled
from bonobo.nodes.io.file import FileReader, FileWriter
class JsonHandler(): class JsonHandler(FileHandler):
eol = ',\n' eol = ',\n'
prefix, suffix = '[', ']' prefix, suffix = '[', ']'
class JsonReader(JsonHandler, FileReader): class JsonReader(IOFormatEnabled, FileReader, JsonHandler):
loader = staticmethod(json.load) loader = staticmethod(json.load)
def read(self, fs, file): def read(self, fs, file):
@ -17,18 +19,21 @@ class JsonReader(JsonHandler, FileReader):
yield self.get_output(line) yield self.get_output(line)
class JsonWriter(JsonHandler, FileWriter): class JsonWriter(IOFormatEnabled, FileWriter, JsonHandler):
@ContextProcessor @ContextProcessor
def envelope(self, context, fs, file, lineno): def envelope(self, context, fs, file, lineno):
file.write(self.prefix) file.write(self.prefix)
yield yield
file.write(self.suffix) file.write(self.suffix)
def write(self, fs, file, lineno, row): def write(self, fs, file, lineno, *args, **kwargs):
""" """
Write a json row on the next line of file pointed by ctx.file. Write a json row on the next line of file pointed by ctx.file.
:param ctx: :param ctx:
:param row: :param row:
""" """
return super().write(fs, file, lineno, json.dumps(row)) row = self.get_input(*args, **kwargs)
self._write_line(file, (self.eol if lineno.value else '') + json.dumps(row))
lineno += 1
return NOT_MODIFIED

View File

@ -1,10 +1,11 @@
import pickle import pickle
from bonobo.config.processors import ContextProcessor
from bonobo.config import Option from bonobo.config import Option
from bonobo.config.processors import ContextProcessor
from bonobo.constants import NOT_MODIFIED from bonobo.constants import NOT_MODIFIED
from bonobo.nodes.io.base import FileHandler, IOFormatEnabled
from bonobo.nodes.io.file import FileReader, FileWriter
from bonobo.util.objects import ValueHolder from bonobo.util.objects import ValueHolder
from .file import FileReader, FileWriter, FileHandler
class PickleHandler(FileHandler): class PickleHandler(FileHandler):
@ -19,7 +20,7 @@ class PickleHandler(FileHandler):
item_names = Option(tuple) item_names = Option(tuple)
class PickleReader(PickleHandler, FileReader): class PickleReader(IOFormatEnabled, FileReader, PickleHandler):
""" """
Reads a Python pickle object and yields the items in dicts. Reads a Python pickle object and yields the items in dicts.
""" """
@ -56,8 +57,7 @@ class PickleReader(PickleHandler, FileReader):
yield self.get_output(dict(zip(i)) if is_dict else dict(zip(pickle_headers.value, i))) yield self.get_output(dict(zip(i)) if is_dict else dict(zip(pickle_headers.value, i)))
class PickleWriter(PickleHandler, FileWriter): class PickleWriter(IOFormatEnabled, FileWriter, PickleHandler):
mode = Option(str, default='wb') mode = Option(str, default='wb')
def write(self, fs, file, lineno, item): def write(self, fs, file, lineno, item):

View File

@ -1,6 +1,7 @@
from contextlib import contextmanager from contextlib import contextmanager
from unittest.mock import MagicMock from unittest.mock import MagicMock
from bonobo import open_fs
from bonobo.execution.node import NodeExecutionContext from bonobo.execution.node import NodeExecutionContext
@ -17,3 +18,20 @@ def optional_contextmanager(cm, *, ignore=False):
else: else:
with cm: with cm:
yield yield
class FilesystemTester:
    """Test helper building a filesystem, a filename and a services mapping.

    Attributes:
        extension: suffix appended to the ``input.`` / ``output.`` file names.
        input_data: content written to the input file for reader tests
            (empty string until the test module assigns it).
        mode: mode used to open the input file when writing ``input_data``.
    """

    def __init__(self, extension='txt', mode='w'):
        self.extension = extension
        self.input_data = ''
        self.mode = mode

    def get_services_for_reader(self, tmpdir):
        """Create ``input.<extension>`` filled with ``input_data`` under *tmpdir*.

        Returns:
            A ``(fs, filename, services)`` tuple where ``services`` maps
            ``'fs'`` to the opened filesystem.
        """
        fs = open_fs(tmpdir)
        filename = 'input.' + self.extension
        with fs.open(filename, self.mode) as stream:
            stream.write(self.input_data)
        services = {'fs': fs}
        return fs, filename, services

    def get_services_for_writer(self, tmpdir):
        """Name ``output.<extension>`` under *tmpdir* without creating it.

        Returns:
            A ``(fs, filename, services)`` tuple where ``services`` maps
            ``'fs'`` to the opened filesystem.
        """
        fs = open_fs(tmpdir)
        filename = 'output.' + self.extension
        services = {'fs': fs}
        return fs, filename, services

View File

View File

@ -1,23 +1,21 @@
import pytest import pytest
from bonobo import Bag, CsvReader, CsvWriter, open_fs, settings from bonobo import Bag, CsvReader, CsvWriter, settings
from bonobo.constants import BEGIN, END from bonobo.constants import BEGIN, END
from bonobo.execution.node import NodeExecutionContext from bonobo.execution.node import NodeExecutionContext
from bonobo.util.testing import CapturingNodeExecutionContext from bonobo.util.testing import CapturingNodeExecutionContext, FilesystemTester
csv_tester = FilesystemTester('csv')
csv_tester.input_data = 'a,b,c\na foo,b foo,c foo\na bar,b bar,c bar'
def test_write_csv_to_file(tmpdir): def test_write_csv_to_file_arg0(tmpdir):
fs, filename = open_fs(tmpdir), 'output.csv' fs, filename, services = csv_tester.get_services_for_writer(tmpdir)
writer = CsvWriter(path=filename, ioformat=settings.IOFORMAT_ARG0)
context = NodeExecutionContext(writer, services={'fs': fs})
with NodeExecutionContext(CsvWriter(path=filename, ioformat=settings.IOFORMAT_ARG0), services=services) as context:
context.write(BEGIN, Bag({'foo': 'bar'}), Bag({'foo': 'baz', 'ignore': 'this'}), END) context.write(BEGIN, Bag({'foo': 'bar'}), Bag({'foo': 'baz', 'ignore': 'this'}), END)
context.start()
context.step() context.step()
context.step() context.step()
context.stop()
with fs.open(filename) as fp: with fs.open(filename) as fp:
assert fp.read() == 'foo\nbar\nbaz\n' assert fp.read() == 'foo\nbar\nbaz\n'
@ -26,19 +24,33 @@ def test_write_csv_to_file(tmpdir):
getattr(context, 'file') getattr(context, 'file')
def test_read_csv_from_file(tmpdir): @pytest.mark.parametrize('add_kwargs', ({}, {
fs, filename = open_fs(tmpdir), 'input.csv' 'ioformat': settings.IOFORMAT_KWARGS,
with fs.open(filename, 'w') as fp: },))
fp.write('a,b,c\na foo,b foo,c foo\na bar,b bar,c bar') def test_write_csv_to_file_kwargs(tmpdir, add_kwargs):
fs, filename, services = csv_tester.get_services_for_writer(tmpdir)
reader = CsvReader(path=filename, delimiter=',', ioformat=settings.IOFORMAT_ARG0) with NodeExecutionContext(CsvWriter(path=filename, **add_kwargs), services=services) as context:
context.write(BEGIN, Bag(**{'foo': 'bar'}), Bag(**{'foo': 'baz', 'ignore': 'this'}), END)
context.step()
context.step()
context = CapturingNodeExecutionContext(reader, services={'fs': fs}) with fs.open(filename) as fp:
assert fp.read() == 'foo\nbar\nbaz\n'
context.start() with pytest.raises(AttributeError):
getattr(context, 'file')
def test_read_csv_from_file_arg0(tmpdir):
fs, filename, services = csv_tester.get_services_for_reader(tmpdir)
with CapturingNodeExecutionContext(
CsvReader(path=filename, delimiter=',', ioformat=settings.IOFORMAT_ARG0),
services=services,
) as context:
context.write(BEGIN, Bag(), END) context.write(BEGIN, Bag(), END)
context.step() context.step()
context.stop()
assert len(context.send.mock_calls) == 2 assert len(context.send.mock_calls) == 2
@ -59,19 +71,15 @@ def test_read_csv_from_file(tmpdir):
} }
def test_read_csv_kwargs_output_formater(tmpdir): def test_read_csv_from_file_kwargs(tmpdir):
fs, filename = open_fs(tmpdir), 'input.csv' fs, filename, services = csv_tester.get_services_for_reader(tmpdir)
with fs.open(filename, 'w') as fp:
fp.write('a,b,c\na foo,b foo,c foo\na bar,b bar,c bar')
reader = CsvReader(path=filename, delimiter=',') with CapturingNodeExecutionContext(
CsvReader(path=filename, delimiter=','),
context = CapturingNodeExecutionContext(reader, services={'fs': fs}) services=services,
) as context:
context.start()
context.write(BEGIN, Bag(), END) context.write(BEGIN, Bag(), END)
context.step() context.step()
context.stop()
assert len(context.send.mock_calls) == 2 assert len(context.send.mock_calls) == 2

View File

@ -1,9 +1,22 @@
import pytest import pytest
from bonobo import Bag, FileReader, FileWriter, open_fs from bonobo import Bag, FileReader, FileWriter
from bonobo.constants import BEGIN, END from bonobo.constants import BEGIN, END
from bonobo.execution.node import NodeExecutionContext from bonobo.execution.node import NodeExecutionContext
from bonobo.util.testing import CapturingNodeExecutionContext from bonobo.util.testing import CapturingNodeExecutionContext, FilesystemTester
txt_tester = FilesystemTester('txt')
txt_tester.input_data = 'Hello\nWorld\n'
def test_file_writer_contextless(tmpdir):
fs, filename, services = txt_tester.get_services_for_writer(tmpdir)
with FileWriter(path=filename).open(fs) as fp:
fp.write('Yosh!')
with fs.open(filename) as fp:
assert fp.read() == 'Yosh!'
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -14,46 +27,23 @@ from bonobo.util.testing import CapturingNodeExecutionContext
] ]
) )
def test_file_writer_in_context(tmpdir, lines, output): def test_file_writer_in_context(tmpdir, lines, output):
fs, filename = open_fs(tmpdir), 'output.txt' fs, filename, services = txt_tester.get_services_for_writer(tmpdir)
writer = FileWriter(path=filename) with NodeExecutionContext(FileWriter(path=filename), services=services) as context:
context = NodeExecutionContext(writer, services={'fs': fs})
context.start()
context.write(BEGIN, *map(Bag, lines), END) context.write(BEGIN, *map(Bag, lines), END)
for _ in range(len(lines)): for _ in range(len(lines)):
context.step() context.step()
context.stop()
with fs.open(filename) as fp: with fs.open(filename) as fp:
assert fp.read() == output assert fp.read() == output
def test_file_writer_out_of_context(tmpdir): def test_file_reader(tmpdir):
fs, filename = open_fs(tmpdir), 'output.txt' fs, filename, services = txt_tester.get_services_for_reader(tmpdir)
writer = FileWriter(path=filename) with CapturingNodeExecutionContext(FileReader(path=filename), services=services) as context:
with writer.open(fs) as fp:
fp.write('Yosh!')
with fs.open(filename) as fp:
assert fp.read() == 'Yosh!'
def test_file_reader_in_context(tmpdir):
fs, filename = open_fs(tmpdir), 'input.txt'
with fs.open(filename, 'w') as fp:
fp.write('Hello\nWorld\n')
reader = FileReader(path=filename)
context = CapturingNodeExecutionContext(reader, services={'fs': fs})
context.start()
context.write(BEGIN, Bag(), END) context.write(BEGIN, Bag(), END)
context.step() context.step()
context.stop()
assert len(context.send.mock_calls) == 2 assert len(context.send.mock_calls) == 2

View File

@ -1,44 +1,48 @@
import pytest import pytest
from bonobo import Bag, JsonReader, JsonWriter, open_fs, settings from bonobo import Bag, JsonReader, JsonWriter, settings
from bonobo.constants import BEGIN, END from bonobo.constants import BEGIN, END
from bonobo.execution.node import NodeExecutionContext from bonobo.execution.node import NodeExecutionContext
from bonobo.util.testing import CapturingNodeExecutionContext from bonobo.util.testing import CapturingNodeExecutionContext, FilesystemTester
json_tester = FilesystemTester('json')
json_tester.input_data = '''[{"x": "foo"},{"x": "bar"}]'''
def test_write_json_to_file(tmpdir): def test_write_json_arg0(tmpdir):
fs, filename = open_fs(tmpdir), 'output.json' fs, filename, services = json_tester.get_services_for_writer(tmpdir)
writer = JsonWriter(filename, ioformat=settings.IOFORMAT_ARG0) with NodeExecutionContext(JsonWriter(filename, ioformat=settings.IOFORMAT_ARG0), services=services) as context:
context = NodeExecutionContext(writer, services={'fs': fs})
context.start()
context.write(BEGIN, Bag({'foo': 'bar'}), END) context.write(BEGIN, Bag({'foo': 'bar'}), END)
context.step() context.step()
context.stop()
with fs.open(filename) as fp: with fs.open(filename) as fp:
assert fp.read() == '[{"foo": "bar"}]' assert fp.read() == '[{"foo": "bar"}]'
with pytest.raises(AttributeError):
getattr(context, 'file')
with pytest.raises(AttributeError): @pytest.mark.parametrize('add_kwargs', ({}, {
getattr(context, 'first') 'ioformat': settings.IOFORMAT_KWARGS,
}, ))
def test_write_json_kwargs(tmpdir, add_kwargs):
fs, filename, services = json_tester.get_services_for_writer(tmpdir)
with NodeExecutionContext(JsonWriter(filename, **add_kwargs), services=services) as context:
context.write(BEGIN, Bag(**{'foo': 'bar'}), END)
context.step()
with fs.open(filename) as fp:
assert fp.read() == '[{"foo": "bar"}]'
def test_read_json_from_file(tmpdir): def test_read_json_arg0(tmpdir):
fs, filename = open_fs(tmpdir), 'input.json' fs, filename, services = json_tester.get_services_for_reader(tmpdir)
with fs.open(filename, 'w') as fp:
fp.write('[{"x": "foo"},{"x": "bar"}]')
reader = JsonReader(filename, ioformat=settings.IOFORMAT_ARG0)
context = CapturingNodeExecutionContext(reader, services={'fs': fs}) with CapturingNodeExecutionContext(
JsonReader(filename, ioformat=settings.IOFORMAT_ARG0),
context.start() services=services,
) as context:
context.write(BEGIN, Bag(), END) context.write(BEGIN, Bag(), END)
context.step() context.step()
context.stop()
assert len(context.send.mock_calls) == 2 assert len(context.send.mock_calls) == 2

View File

@ -2,24 +2,22 @@ import pickle
import pytest import pytest
from bonobo import Bag, PickleReader, PickleWriter, open_fs, settings from bonobo import Bag, PickleReader, PickleWriter, settings
from bonobo.constants import BEGIN, END from bonobo.constants import BEGIN, END
from bonobo.execution.node import NodeExecutionContext from bonobo.execution.node import NodeExecutionContext
from bonobo.util.testing import CapturingNodeExecutionContext from bonobo.util.testing import CapturingNodeExecutionContext, FilesystemTester
pickle_tester = FilesystemTester('pkl', mode='wb')
pickle_tester.input_data = pickle.dumps([['a', 'b', 'c'], ['a foo', 'b foo', 'c foo'], ['a bar', 'b bar', 'c bar']])
def test_write_pickled_dict_to_file(tmpdir): def test_write_pickled_dict_to_file(tmpdir):
fs, filename = open_fs(tmpdir), 'output.pkl' fs, filename, services = pickle_tester.get_services_for_writer(tmpdir)
writer = PickleWriter(filename, ioformat=settings.IOFORMAT_ARG0)
context = NodeExecutionContext(writer, services={'fs': fs})
with NodeExecutionContext(PickleWriter(filename, ioformat=settings.IOFORMAT_ARG0), services=services) as context:
context.write(BEGIN, Bag({'foo': 'bar'}), Bag({'foo': 'baz', 'ignore': 'this'}), END) context.write(BEGIN, Bag({'foo': 'bar'}), Bag({'foo': 'baz', 'ignore': 'this'}), END)
context.start()
context.step() context.step()
context.step() context.step()
context.stop()
with fs.open(filename, 'rb') as fp: with fs.open(filename, 'rb') as fp:
assert pickle.loads(fp.read()) == {'foo': 'bar'} assert pickle.loads(fp.read()) == {'foo': 'bar'}
@ -29,18 +27,13 @@ def test_write_pickled_dict_to_file(tmpdir):
def test_read_pickled_list_from_file(tmpdir): def test_read_pickled_list_from_file(tmpdir):
fs, filename = open_fs(tmpdir), 'input.pkl' fs, filename, services = pickle_tester.get_services_for_reader(tmpdir)
with fs.open(filename, 'wb') as fp:
fp.write(pickle.dumps([['a', 'b', 'c'], ['a foo', 'b foo', 'c foo'], ['a bar', 'b bar', 'c bar']]))
reader = PickleReader(filename, ioformat=settings.IOFORMAT_ARG0) with CapturingNodeExecutionContext(
PickleReader(filename, ioformat=settings.IOFORMAT_ARG0), services=services
context = CapturingNodeExecutionContext(reader, services={'fs': fs}) ) as context:
context.start()
context.write(BEGIN, Bag(), END) context.write(BEGIN, Bag(), END)
context.step() context.step()
context.stop()
assert len(context.send.mock_calls) == 2 assert len(context.send.mock_calls) == 2