diff --git a/bonobo/commands/init.py b/bonobo/commands/init.py index 948f747..e69156c 100644 --- a/bonobo/commands/init.py +++ b/bonobo/commands/init.py @@ -1,3 +1,5 @@ +import os + def execute(name, branch): try: from cookiecutter.main import cookiecutter @@ -6,11 +8,17 @@ def execute(name, branch): 'You must install "cookiecutter" to use this command.\n\n $ pip install cookiecutter\n' ) from exc + overwrite_if_exists = False + project_path = os.path.join(os.getcwd(), name) + if os.path.isdir(project_path) and not os.listdir(project_path): + overwrite_if_exists = True + return cookiecutter( 'https://github.com/python-bonobo/cookiecutter-bonobo.git', extra_context={'name': name}, no_input=True, - checkout=branch + checkout=branch, + overwrite_if_exists=overwrite_if_exists ) diff --git a/tests/test_commands.py b/tests/test_commands.py index e467bb3..a96634c 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -8,18 +8,24 @@ from unittest.mock import patch import pkg_resources import pytest +from cookiecutter.exceptions import OutputDirExistsException from bonobo import __main__, __version__, get_examples_path from bonobo.commands import entrypoint +from bonobo.commands.run import DEFAULT_GRAPH_FILENAMES def runner(f): @functools.wraps(f) - def wrapped_runner(*args): + def wrapped_runner(*args, catch_errors=False): with redirect_stdout(io.StringIO()) as stdout, redirect_stderr(io.StringIO()) as stderr: try: f(list(args)) except BaseException as exc: + if not catch_errors: + raise + elif isinstance(catch_errors, BaseException) and not isinstance(exc, catch_errors): + raise return stdout.getvalue(), stderr.getvalue(), exc return stdout.getvalue(), stderr.getvalue() @@ -40,6 +46,7 @@ def runner_module(args): all_runners = pytest.mark.parametrize('runner', [runner_entrypoint, runner_module]) +single_runner = pytest.mark.parametrize('runner', [runner_module]) def test_entrypoint(): @@ -59,11 +66,41 @@ def test_entrypoint(): @all_runners def test_no_command(runner): - _, err, exc = runner() + _, err, exc = runner(catch_errors=True) assert type(exc) == SystemExit assert 'error: the following arguments are required: command' in err +@all_runners +def test_init(runner, tmpdir): + name = 'project' + tmpdir.chdir() + runner('init', name) + assert os.path.isdir(name) + assert set(os.listdir(name)) & set(DEFAULT_GRAPH_FILENAMES) + +@single_runner +def test_init_in_empty_then_nonempty_directory(runner, tmpdir): + name = 'project' + tmpdir.chdir() + os.mkdir(name) + + # run in empty dir + runner('init', name) + assert set(os.listdir(name)) & set(DEFAULT_GRAPH_FILENAMES) + + # run in non empty dir + with pytest.raises(OutputDirExistsException): + runner('init', name) + + +@single_runner +def test_init_within_empty_directory(runner, tmpdir): + tmpdir.chdir() + runner('init', '.') + assert set(os.listdir()) & set(DEFAULT_GRAPH_FILENAMES) + + @all_runners def test_run(runner): out, err = runner('run', '--quiet', get_examples_path('types/strings.py'))