Merge pull request #82 from hartym/feature/io_pickle

Feature/io pickle
This commit is contained in:
Romain Dorgueil
2017-05-25 07:47:58 -07:00
committed by GitHub
6 changed files with 193 additions and 2 deletions

View File

@ -1,6 +1,6 @@
from bonobo.structs import Bag, Graph, Token from bonobo.structs import Bag, Graph, Token
from bonobo.nodes import CsvReader, CsvWriter, FileReader, FileWriter, Filter, JsonReader, JsonWriter, Limit, \ from bonobo.nodes import CsvReader, CsvWriter, FileReader, FileWriter, Filter, JsonReader, JsonWriter, Limit, \
PrettyPrint, Tee, count, identity, noop, pprint PrettyPrint, PickleWriter, PickleReader, Tee, count, identity, noop, pprint
from bonobo.strategies import create_strategy from bonobo.strategies import create_strategy
from bonobo.util.objects import get_name from bonobo.util.objects import get_name
@ -43,7 +43,6 @@ def run(graph, strategy=None, plugins=None, services=None):
plugins = plugins or [] plugins = plugins or []
from bonobo import settings from bonobo import settings
settings.check() settings.check()
if not settings.QUIET: # pragma: no cover if not settings.QUIET: # pragma: no cover
@ -98,6 +97,8 @@ register_api_group(
JsonReader, JsonReader,
JsonWriter, JsonWriter,
Limit, Limit,
PickleReader,
PickleWriter,
PrettyPrint, PrettyPrint,
Tee, Tee,
count, count,

Binary file not shown.

View File

@ -0,0 +1,58 @@
import bonobo
from fs.tarfs import TarFS
import os
def cleanse_sms(row):
if row['category'] == 'spam':
row['sms_clean'] = '**MARKED AS SPAM** ' + row['sms'][0:50] + ('...' if len(row['sms']) > 50 else '')
else:
row['sms_clean'] = row['sms']
return row['sms_clean']
graph = bonobo.Graph(
bonobo.PickleReader('spam.pkl'), # spam.pkl is within the gzipped tarball
cleanse_sms,
print
)
if __name__ == '__main__':
'''
This example shows how a different file system service can be injected
into a transformation (as compressing pickled objects often makes sense
anyways). The pickle itself contains a list of lists as follows:
```
[
['category', 'sms'],
['ham', 'Go until jurong point, crazy..'],
['ham', 'Ok lar... Joking wif u oni...'],
['spam', 'Free entry in 2 a wkly comp to win...'],
['ham', 'U dun say so early hor... U c already then say...'],
['ham', 'Nah I don't think he goes to usf, he lives around here though'],
['spam', 'FreeMsg Hey there darling it's been 3 week's now...'],
...
]
```
where the first column categorizes and sms as "ham" or "spam". The second
column contains the sms itself.
Data set taken from:
https://www.kaggle.com/uciml/sms-spam-collection-dataset/downloads/sms-spam-collection-dataset.zip
The transformation (1) reads the pickled data, (2) marks and shortens
messages categorized as spam, and (3) prints the output.
'''
services = {
'fs': TarFS(
os.path.join(bonobo.get_examples_path(), 'datasets', 'spam.tgz')
)
}
bonobo.run(graph, services=services)

View File

@ -3,6 +3,7 @@
from .file import FileReader, FileWriter from .file import FileReader, FileWriter
from .json import JsonReader, JsonWriter from .json import JsonReader, JsonWriter
from .csv import CsvReader, CsvWriter from .csv import CsvReader, CsvWriter
from .pickle import PickleReader, PickleWriter
__all__ = [ __all__ = [
'CsvReader', 'CsvReader',
@ -11,4 +12,6 @@ __all__ = [
'FileWriter', 'FileWriter',
'JsonReader', 'JsonReader',
'JsonWriter', 'JsonWriter',
'PickleReader',
'PickleWriter',
] ]

69
bonobo/nodes/io/pickle.py Normal file
View File

@ -0,0 +1,69 @@
import pickle
from bonobo.config.processors import ContextProcessor
from bonobo.config import Option
from bonobo.constants import NOT_MODIFIED
from bonobo.util.objects import ValueHolder
from .file import FileReader, FileWriter, FileHandler
class PickleHandler(FileHandler):
"""
.. attribute:: item_names
The names of the items in the pickle, if it is not defined in the first item of the pickle.
"""
item_names = Option(tuple)
class PickleReader(PickleHandler, FileReader):
"""
Reads a Python pickle object and yields the items in dicts.
"""
mode = Option(str, default='rb')
@ContextProcessor
def pickle_headers(self, context, fs, file):
yield ValueHolder(self.item_names)
def read(self, fs, file, pickle_headers):
data = pickle.load(file)
# if the data is not iterable, then wrap the object in a list so it may be iterated
if isinstance(data, dict):
is_dict = True
iterator = iter(data.items())
else:
is_dict = False
try:
iterator = iter(data)
except TypeError:
iterator = iter([data])
if not pickle_headers.get():
pickle_headers.set(next(iterator))
item_count = len(pickle_headers.value)
for i in iterator:
if len(i) != item_count:
raise ValueError('Received an object with %d items, expecting %d.' % (len(i), item_count, ))
yield dict(zip(i)) if is_dict else dict(zip(pickle_headers.value, i))
class PickleWriter(PickleHandler, FileWriter):
mode = Option(str, default='wb')
def write(self, fs, file, lineno, item):
"""
Write a pickled item to the opened file.
"""
file.write(pickle.dumps(item))
lineno += 1
return NOT_MODIFIED

60
tests/io/test_pickle.py Normal file
View File

@ -0,0 +1,60 @@
import pickle
import pytest
from bonobo import Bag, PickleReader, PickleWriter, open_fs
from bonobo.constants import BEGIN, END
from bonobo.execution.node import NodeExecutionContext
from bonobo.util.testing import CapturingNodeExecutionContext
def test_write_pickled_dict_to_file(tmpdir):
fs, filename = open_fs(tmpdir), 'output.pkl'
writer = PickleWriter(path=filename)
context = NodeExecutionContext(writer, services={'fs': fs})
context.write(BEGIN, Bag({'foo': 'bar'}), Bag({'foo': 'baz', 'ignore': 'this'}), END)
context.start()
context.step()
context.step()
context.stop()
assert pickle.loads(fs.open(filename, 'rb').read()) == {'foo': 'bar'}
with pytest.raises(AttributeError):
getattr(context, 'file')
def test_read_pickled_list_from_file(tmpdir):
fs, filename = open_fs(tmpdir), 'input.pkl'
fs.open(filename, 'wb').write(pickle.dumps([
['a', 'b', 'c'], ['a foo', 'b foo', 'c foo'], ['a bar', 'b bar', 'c bar']
]))
reader = PickleReader(path=filename)
context = CapturingNodeExecutionContext(reader, services={'fs': fs})
context.start()
context.write(BEGIN, Bag(), END)
context.step()
context.stop()
assert len(context.send.mock_calls) == 2
args0, kwargs0 = context.send.call_args_list[0]
assert len(args0) == 1 and not len(kwargs0)
args1, kwargs1 = context.send.call_args_list[1]
assert len(args1) == 1 and not len(kwargs1)
assert args0[0].args[0] == {
'a': 'a foo',
'b': 'b foo',
'c': 'c foo',
}
assert args1[0].args[0] == {
'a': 'a bar',
'b': 'b bar',
'c': 'c bar',
}