diff --git a/bonobo/commands/init.py b/bonobo/commands/init.py index 948f747..e69156c 100644 --- a/bonobo/commands/init.py +++ b/bonobo/commands/init.py @@ -1,3 +1,5 @@ +import os + def execute(name, branch): try: from cookiecutter.main import cookiecutter @@ -6,11 +8,17 @@ def execute(name, branch): 'You must install "cookiecutter" to use this command.\n\n $ pip install cookiecutter\n' ) from exc + overwrite_if_exists = False + project_path = os.path.join(os.getcwd(), name) + if os.path.isdir(project_path) and not os.listdir(project_path): + overwrite_if_exists = True + return cookiecutter( 'https://github.com/python-bonobo/cookiecutter-bonobo.git', extra_context={'name': name}, no_input=True, - checkout=branch + checkout=branch, + overwrite_if_exists=overwrite_if_exists ) diff --git a/bonobo/examples/tutorials/tut02e02_write.py b/bonobo/examples/tutorials/tut02e02_write.py index e5a8445..c4b065d 100644 --- a/bonobo/examples/tutorials/tut02e02_write.py +++ b/bonobo/examples/tutorials/tut02e02_write.py @@ -2,14 +2,14 @@ import bonobo def split_one(line): - return line.split(', ', 1) + return dict(zip(("name", "address"), line.split(', ', 1))) graph = bonobo.Graph( bonobo.FileReader('coffeeshops.txt'), split_one, bonobo.JsonWriter( - 'coffeeshops.json', fs='fs.output', ioformat='arg0' + 'coffeeshops.json', fs='fs.output' ), ) diff --git a/bonobo/examples/tutorials/tut02e03_writeasmap.py b/bonobo/examples/tutorials/tut02e03_writeasmap.py index e234f22..c7c7711 100644 --- a/bonobo/examples/tutorials/tut02e03_writeasmap.py +++ b/bonobo/examples/tutorials/tut02e03_writeasmap.py @@ -11,7 +11,7 @@ def split_one_to_map(line): class MyJsonWriter(bonobo.JsonWriter): prefix, suffix = '{', '}' - def write(self, fs, file, lineno, row): + def write(self, fs, file, lineno, **row): return bonobo.FileWriter.write( self, fs, file, lineno, json.dumps(row)[1:-1] ) @@ -20,7 +20,7 @@ class MyJsonWriter(bonobo.JsonWriter): graph = bonobo.Graph( bonobo.FileReader('coffeeshops.txt'), split_one_to_map, - MyJsonWriter('coffeeshops.json', fs='fs.output', ioformat='arg0'), + MyJsonWriter('coffeeshops.json', fs='fs.output'), ) diff --git a/docs/tutorial/tut01.rst b/docs/tutorial/tut01.rst index d6aa604..3d6f9eb 100644 --- a/docs/tutorial/tut01.rst +++ b/docs/tutorial/tut01.rst @@ -105,6 +105,9 @@ To do this, it needs to know what data-flow you want to achieve, and you'll use The `if __name__ == '__main__':` section is not required, unless you want to run it directly using the python interpreter. + The name of the `graph` variable is arbitrary, but this variable must be global and available unconditionally. + Do not put it in its own function or in the `if __name__ == '__main__':` section. + Execute the job ::::::::::::::: diff --git a/tests/test_commands.py b/tests/test_commands.py index e467bb3..a96634c 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -8,18 +8,24 @@ from unittest.mock import patch import pkg_resources import pytest +from cookiecutter.exceptions import OutputDirExistsException from bonobo import __main__, __version__, get_examples_path from bonobo.commands import entrypoint +from bonobo.commands.run import DEFAULT_GRAPH_FILENAMES def runner(f): @functools.wraps(f) - def wrapped_runner(*args): + def wrapped_runner(*args, catch_errors=False): with redirect_stdout(io.StringIO()) as stdout, redirect_stderr(io.StringIO()) as stderr: try: f(list(args)) except BaseException as exc: + if not catch_errors: + raise + elif isinstance(catch_errors, BaseException) and not isinstance(exc, catch_errors): + raise return stdout.getvalue(), stderr.getvalue(), exc return stdout.getvalue(), stderr.getvalue() @@ -40,6 +46,7 @@ def runner_module(args): all_runners = pytest.mark.parametrize('runner', [runner_entrypoint, runner_module]) +single_runner = pytest.mark.parametrize('runner', [runner_module]) def test_entrypoint(): @@ -59,11 +66,41 @@ def test_entrypoint(): @all_runners def test_no_command(runner): - _, err, exc = runner() + _, err, exc = runner(catch_errors=True) assert type(exc) == SystemExit assert 'error: the following arguments are required: command' in err +@all_runners +def test_init(runner, tmpdir): + name = 'project' + tmpdir.chdir() + runner('init', name) + assert os.path.isdir(name) + assert set(os.listdir(name)) & set(DEFAULT_GRAPH_FILENAMES) + +@single_runner +def test_init_in_empty_then_nonempty_directory(runner, tmpdir): + name = 'project' + tmpdir.chdir() + os.mkdir(name) + + # run in empty dir + runner('init', name) + assert set(os.listdir(name)) & set(DEFAULT_GRAPH_FILENAMES) + + # run in non empty dir + with pytest.raises(OutputDirExistsException): + runner('init', name) + + +@single_runner +def test_init_within_empty_directory(runner, tmpdir): + tmpdir.chdir() + runner('init', '.') + assert set(os.listdir()) & set(DEFAULT_GRAPH_FILENAMES) + + @all_runners def test_run(runner): out, err = runner('run', '--quiet', get_examples_path('types/strings.py'))