From c5007dbdb88c406e9d22f43dc6f0aa4cf23e0f2e Mon Sep 17 00:00:00 2001 From: Agampreet Singh Date: Thu, 15 Aug 2024 23:58:57 +0530 Subject: [PATCH] Remove required_ from core & added tests to ensure both are working --- src/tirith/core/core.py | 9 ++++--- tests/core/test_core.py | 55 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/src/tirith/core/core.py b/src/tirith/core/core.py index 6f002056..4250d507 100644 --- a/src/tirith/core/core.py +++ b/src/tirith/core/core.py @@ -225,14 +225,14 @@ def start_policy_evaluation_from_dict(policy_dict: Dict, input_dict: Dict) -> Di policy_meta = policy_dict.get("meta") eval_objects = policy_dict.get("evaluators") final_evaluation_policy_string = policy_dict.get("eval_expression") - provider_module = policy_meta.get("required_provider", "core") - # TODO: Write functionality for dynamically importing evaluators from other modules. + + provider_module = policy_meta.get("provider", policy_meta.get("required_provider", "core")) + eval_results = [] eval_results_obj = {} for eval_obj in eval_objects: eval_id = eval_obj.get("id") eval_description = eval_obj.get("description") - logger.debug(f"Processing evaluator '{eval_id}'") eval_result = generate_evaluator_result(eval_obj, input_dict, provider_module) eval_result["id"] = eval_id eval_result["description"] = eval_description @@ -241,10 +241,11 @@ def start_policy_evaluation_from_dict(policy_dict: Dict, input_dict: Dict) -> Di final_evaluation_result, errors = final_evaluator(final_evaluation_policy_string, eval_results_obj) final_output = { - "meta": {"version": policy_meta.get("version"), "required_provider": provider_module}, + "meta": {"version": policy_meta.get("version"), "provider": provider_module}, "final_result": final_evaluation_result, "evaluators": eval_results, "errors": errors, "eval_expression": final_evaluation_policy_string, } return final_output + diff --git a/tests/core/test_core.py b/tests/core/test_core.py index fa793e31..c1d32b6e 100644 --- a/tests/core/test_core.py +++ b/tests/core/test_core.py @@ -2,7 +2,7 @@ from pytest import mark from tirith.core.core import final_evaluator - +from tirith.core.core import start_policy_evaluation_from_dict @mark.passing def test_final_evaluator_skipped_check_should_be_removed(): @@ -38,3 +38,56 @@ def test_final_evaluator_malicious_eval_should_err(): "!skipped_check && passing_check || [].__class__.__base__", dict(skipped_check=None, passing_check=True) ) assert actual_result == (False, ["The following symbols are not allowed: __class__, __base__"]) + + +@mark.passing +def test_start_policy_evaluation_with_required_provider(): + policy_dict = { + "meta": {"version": "1.0", "required_provider": "legacy_provider"}, + "evaluators": [], + "eval_expression": "True", + } + input_dict = {} + + result = start_policy_evaluation_from_dict(policy_dict, input_dict) + + assert result["meta"]["provider"] == "legacy_provider" + +@mark.passing +def test_start_policy_evaluation_with_provider(): + policy_dict = { + "meta": {"version": "1.0", "provider": "new_provider"}, + "evaluators": [], + "eval_expression": "True", + } + input_dict = {} + + result = start_policy_evaluation_from_dict(policy_dict, input_dict) + + assert result["meta"]["provider"] == "new_provider" + +@mark.passing +def test_start_policy_evaluation_with_both_providers(): + policy_dict = { + "meta": {"version": "1.0", "provider": "new_provider", "required_provider": "legacy_provider"}, + "evaluators": [], + "eval_expression": "True", + } + input_dict = {} + + result = start_policy_evaluation_from_dict(policy_dict, input_dict) + + assert result["meta"]["provider"] == "new_provider" + +@mark.passing +def test_start_policy_evaluation_with_neither_provider(): + policy_dict = { + "meta": {"version": "1.0"}, + "evaluators": [], + "eval_expression": "True", + } + input_dict = {} + + result = start_policy_evaluation_from_dict(policy_dict, input_dict) + + assert result["meta"]["provider"] == "core" \ No newline at end of file