diff --git a/bonobo/commands/run.py b/bonobo/commands/run.py index 4757a54..59190be 100644 --- a/bonobo/commands/run.py +++ b/bonobo/commands/run.py @@ -1,8 +1,10 @@ import codecs import os +from pathlib import Path import bonobo from bonobo.constants import DEFAULT_SERVICES_ATTR, DEFAULT_SERVICES_FILENAME +from dotenv import load_dotenv DEFAULT_GRAPH_FILENAMES = ('__main__.py', 'main.py', ) DEFAULT_GRAPH_ATTR = 'get_graph' @@ -41,7 +43,7 @@ def _install_requirements(requirements): importlib.reload(site) -def read(filename, module, install=False, quiet=False, verbose=False, env=None): +def read(filename, module, install=False, quiet=False, verbose=False, default_env_file=None, default_env=None, env_file=None, env=None): import runpy from bonobo import Graph, settings @@ -52,21 +54,6 @@ def read(filename, module, install=False, quiet=False, verbose=False, env=None): if verbose: settings.DEBUG.set(True) - if env: - __escape_decoder = codecs.getdecoder('unicode_escape') - - def decode_escaped(escaped): - return __escape_decoder(escaped)[0] - - for e in env: - ename, evalue = e.split('=', 1) - - if len(evalue) > 0: - if evalue[0] == evalue[len(evalue) - 1] in ['"', "'"]: - evalue = decode_escaped(evalue[1:-1]) - - os.environ[ename] = evalue - if filename: if os.path.isdir(filename): if install: @@ -84,12 +71,38 @@ def read(filename, module, install=False, quiet=False, verbose=False, env=None): requirements = os.path.join(os.path.dirname(filename), 'requirements.txt') _install_requirements(requirements) context = runpy.run_path(filename, run_name='__bonobo__') + env_dir = Path(filename).parent elif module: context = runpy.run_module(module, run_name='__bonobo__') filename = context['__file__'] + env_dir = Path(module) else: raise RuntimeError('UNEXPECTED: argparse should not allow this.') + if default_env_file: + for f in default_env_file: + env_file_path = env_dir.joinpath(f) + load_dotenv(env_file_path) + else: + try: + env_file_path = env_dir.joinpath('.env') + load_dotenv(env_file_path) + except FileNotFoundError: + pass + + if default_env: + for e in default_env: + set_env_var(e) + + if env_file: + for f in env_file: + env_file_path = env_dir.joinpath(f) + load_dotenv(env_file_path, override=True) + + if env: + for e in env: + set_env_var(e, override=True) + graphs = dict((k, v) for k, v in context.items() if isinstance(v, Graph)) assert len(graphs) == 1, ( @@ -106,8 +119,25 @@ def read(filename, module, install=False, quiet=False, verbose=False, env=None): return graph, plugins, services -def execute(filename, module, install=False, quiet=False, verbose=False, env=None): - graph, plugins, services = read(filename, module, install, quiet, verbose, env) +def set_env_var(e, override=False): + __escape_decoder = codecs.getdecoder('unicode_escape') + ename, evalue = e.split('=', 1) + + def decode_escaped(escaped): + return __escape_decoder(escaped)[0] + + if len(evalue) > 0: + if evalue[0] == evalue[len(evalue) - 1] in ['"', "'"]: + evalue = decode_escaped(evalue[1:-1]) + + if override: + os.environ[ename] = evalue + else: + os.environ.setdefault(ename, evalue) + + +def execute(filename, module, install=False, quiet=False, verbose=False, default_env_file=None, default_env=None, env_file=None, env=None): + graph, plugins, services = read(filename, module, install, quiet, verbose, default_env_file, default_env, env_file, env) return bonobo.run(graph, plugins=plugins, services=services) @@ -116,6 +146,9 @@ def register_generic_run_arguments(parser, required=True): source_group = parser.add_mutually_exclusive_group(required=required) source_group.add_argument('filename', nargs='?', type=str) source_group.add_argument('--module', '-m', type=str) + parser.add_argument('--default-env-file', action='append') + parser.add_argument('--default-env', action='append') + parser.add_argument('--env-file', action='append') parser.add_argument('--env', '-e', action='append') return parser diff --git a/bonobo/examples/env_vars/get_passed_env_file.py b/bonobo/examples/env_vars/get_passed_env_file.py new file mode 100644 index 0000000..e7a0952 --- /dev/null +++ b/bonobo/examples/env_vars/get_passed_env_file.py @@ -0,0 +1,22 @@ +import os + +import bonobo + + +def extract(): + my_secret = os.getenv('MY_SECRET') + test_user_password = os.getenv('TEST_USER_PASSWORD') + user = os.getenv('USERNAME') + path = os.getenv('PATH') + + return my_secret, test_user_password, user, path + + +def load(s: str): + print(s) + + +graph = bonobo.Graph(extract, load) + +if __name__ == '__main__': + bonobo.run(graph) diff --git a/tests/test_commands.py b/tests/test_commands.py index a29465c..ce7b582 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -98,6 +98,33 @@ def test_version(runner, capsys): assert __version__ in out +@all_runners +def test_run_module_with_default_env_file(runner, capsys): + runner( + 'run', '--quiet', get_examples_path('env_vars/get_passed_env_file.py') + ) + out, err = capsys.readouterr() + out = out.split('\n') + assert out[0] == '321' + assert out[1] == 'sweetpassword' + assert out[2] != 'not_cwandrews_123' + assert out[3] != 'marzo' + + +# @all_runners +# def test_run_with_env_file(runner, capsys): +# runner( +# 'run', '--quiet', +# get_examples_path('env_vars/get_passed_env.py'), '--env', 'ENV_TEST_NUMBER=123', '--env', +# 'ENV_TEST_USER=cwandrews', '--env', "ENV_TEST_STRING='my_test_string'" +# ) +# out, err = capsys.readouterr() +# out = out.split('\n') +# assert out[0] == 'cwandrews' +# assert out[1] == '123' +# assert out[2] == 'my_test_string' + + @all_runners def test_run_with_env(runner, capsys): runner(