diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py index a721d945a..7564c45b6 100644 --- a/hamilton/function_modifiers/base.py +++ b/hamilton/function_modifiers/base.py @@ -258,7 +258,7 @@ def find_injectable_params(nodes: Collection[node.Node]) -> Dict[str, Type[Type] def transform_dag( self, nodes: Collection[node.Node], config: Dict[str, Any], fn: Callable - ) -> Collection[node.Node]: + ) -> List[node.Node]: """Transforms the subDAG by getting the injectable parameters (anything not produced by nodes inside it), then calling the inject_nodes function on it. diff --git a/hamilton/graph.py b/hamilton/graph.py index 93ee23efb..43ccd24ca 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -143,7 +143,7 @@ def update_dependencies( def create_function_graph( - *modules: ModuleType, + *functions: List[Tuple[str, Callable]], config: Dict[str, Any], adapter: lifecycle_base.LifecycleAdapterSet = None, fg: Optional["FunctionGraph"] = None, @@ -164,7 +164,6 @@ def create_function_graph( nodes = {} # name -> Node else: nodes = fg.nodes - functions = sum([find_functions(module) for module in modules], []) # create non-input nodes -- easier to just create this in one loop for _func_name, f in functions: @@ -733,8 +732,26 @@ def from_modules( :return: a function graph. """ + functions = sum([find_functions(module) for module in modules], []) + return FunctionGraph.from_functions( + *functions, + config=config, + adapter=adapter, + allow_module_overrides=allow_module_overrides, + ) + + @staticmethod + def from_functions( + *functions, + config: Dict[str, Any], + adapter: lifecycle_base.LifecycleAdapterSet = None, + allow_module_overrides: bool = False, + ) -> "FunctionGraph": nodes = create_function_graph( - *modules, config=config, adapter=adapter, allow_module_overrides=allow_module_overrides + *functions, + config=config, + adapter=adapter, + allow_module_overrides=allow_module_overrides, ) return FunctionGraph(nodes, config, adapter) diff --git a/hamilton/graph_utils.py b/hamilton/graph_utils.py index 8c3695b4d..91f77aea1 100644 --- a/hamilton/graph_utils.py +++ b/hamilton/graph_utils.py @@ -7,18 +7,25 @@ def is_submodule(child: ModuleType, parent: ModuleType): return parent.__name__ in child.__name__ +def is_hamilton_function(fn: Callable) -> bool: + """A `Hamilton function` defines a node. + To be valid, it must not start with the underscore `_` prefix. + """ + return inspect.isfunction(fn) and not fn.__name__.startswith("_") + + +# NOTE. This should return a list of callables instead of tuples. Internally, +# the function names are never used (except in the CLI) and that information +# is readily available through `fn.__name__`. +# Care is required because users may have been advised to use this code path. def find_functions(function_module: ModuleType) -> List[Tuple[str, Callable]]: """Function to determine the set of functions we want to build a graph from. This iterates through the function module and grabs all function definitions. :return: list of tuples of (func_name, function). """ - - def valid_fn(fn): - return ( - inspect.isfunction(fn) - and not fn.__name__.startswith("_") - and is_submodule(inspect.getmodule(fn), function_module) - ) - - return [f for f in inspect.getmembers(function_module, predicate=valid_fn)] + return [ + (name, fn) + for name, fn in inspect.getmembers(function_module) + if is_hamilton_function(fn) and is_submodule(inspect.getmodule(fn), function_module) + ] diff --git a/tests/execution/test_node_grouping.py b/tests/execution/test_node_grouping.py index 270159fc3..81296d51a 100644 --- a/tests/execution/test_node_grouping.py +++ b/tests/execution/test_node_grouping.py @@ -98,8 +98,10 @@ def bar(foo: int) -> int: # This is hacking around function graph which is messy as it is built of larger components # (modules), and should instead be broken into smaller pieces (functions/nodes), and have utilities # to create it from those. - fn_graph = graph.create_function_graph(ad_hoc_utils.create_temporary_module(bar), config={}) - node_ = fn_graph["bar"] + fn_graph = graph.FunctionGraph.from_modules( + ad_hoc_utils.create_temporary_module(bar), config={} + ) + node_ = fn_graph.nodes["bar"] task = grouping.TaskSpec( base_id="bar", nodes=[node_], @@ -120,8 +122,10 @@ def bar(foo: int, baz: int = 1) -> int: # This is hacking around function graph which is messy as it is built of larger components # (modules), and should instead be broken into smaller pieces (functions/nodes), and have utilities # to create it from those. - fn_graph = graph.create_function_graph(ad_hoc_utils.create_temporary_module(bar), config={}) - node_ = fn_graph["bar"] + fn_graph = graph.FunctionGraph.from_modules( + ad_hoc_utils.create_temporary_module(bar), config={} + ) + node_ = fn_graph.nodes["bar"] task = grouping.TaskSpec( base_id="bar", nodes=[node_], diff --git a/tests/function_modifiers/test_adapters.py b/tests/function_modifiers/test_adapters.py index 06ba57b6b..02d0db0ca 100644 --- a/tests/function_modifiers/test_adapters.py +++ b/tests/function_modifiers/test_adapters.py @@ -570,12 +570,11 @@ def test_loader_default_factory_field(): def foo(param: int) -> int: return param - fg = graph.create_function_graph( - ad_hoc_utils.create_temporary_module(foo), - config={}, + fn_graph = graph.FunctionGraph.from_modules( + ad_hoc_utils.create_temporary_module(foo), config={} ) - assert len(fg) == 3 - assert "foo" in fg + assert len(fn_graph.nodes) == 3 + assert "foo" in fn_graph.nodes @dataclasses.dataclass @@ -602,12 +601,11 @@ def test_saver_default_factory_field(): def foo(param: int) -> int: return param - fg = graph.create_function_graph( - ad_hoc_utils.create_temporary_module(foo), - config={}, + fn_graph = graph.FunctionGraph.from_modules( + ad_hoc_utils.create_temporary_module(foo), config={} ) - assert len(fg) == 3 - assert "foo" in fg + assert len(fn_graph.nodes) == 3 + assert "foo" in fn_graph.nodes @dataclasses.dataclass @@ -631,12 +629,11 @@ def test_adapters_optional_params(): def foo(param: int) -> int: return param - fg = graph.create_function_graph( - ad_hoc_utils.create_temporary_module(foo), - config={}, + fn_graph = graph.FunctionGraph.from_modules( + ad_hoc_utils.create_temporary_module(foo), config={} ) - assert len(fg) == 3 - assert "foo" in fg + assert len(fn_graph.nodes) == 3 + assert "foo" in fn_graph.nodes def test_save_to_with_input_from_other_fn(): @@ -648,12 +645,11 @@ def output_path() -> str: def fn() -> dict: return {"a": 1} - fg = graph.create_function_graph( - ad_hoc_utils.create_temporary_module(output_path, fn), - config={}, + fn_graph = graph.FunctionGraph.from_modules( + ad_hoc_utils.create_temporary_module(output_path, fn), config={} ) - assert len(fg) == 3 + assert len(fn_graph.nodes) == 3 def test_load_from_with_input_from_other_fn(): @@ -665,11 +661,10 @@ def input_path() -> str: def fn(data: dict) -> dict: return data - fg = graph.create_function_graph( - ad_hoc_utils.create_temporary_module(input_path, fn), - config={}, + fn_graph = graph.FunctionGraph.from_modules( + ad_hoc_utils.create_temporary_module(input_path, fn), config={} ) - assert len(fg) == 4 + assert len(fn_graph.nodes) == 4 def test_load_from_with_multiple_inputs(): @@ -686,12 +681,9 @@ def test_load_from_with_multiple_inputs(): def fn(data1: dict, data2: dict) -> dict: return {**data1, **data2} - fg = graph.create_function_graph( - ad_hoc_utils.create_temporary_module(fn), - config={}, - ) + fn_graph = graph.FunctionGraph.from_modules(ad_hoc_utils.create_temporary_module(fn), config={}) # One filter, one loader for each and the transform function - assert len(fg) == 5 + assert len(fn_graph.nodes) == 5 import sys @@ -819,12 +811,11 @@ def test_dataloader_future_annotations(): from tests.resources import nodes_with_future_annotation fn_to_collect = nodes_with_future_annotation.sample_dataloader - fg = graph.create_function_graph( - ad_hoc_utils.create_temporary_module(fn_to_collect), - config={}, + fn_graph = graph.FunctionGraph.from_modules( + ad_hoc_utils.create_temporary_module(fn_to_collect), config={} ) # the data loaded is a list - assert custom_subclass_check(fg["sample_dataloader"].type, list) + assert custom_subclass_check(fn_graph.nodes["sample_dataloader"].type, list) def test_datasaver(): diff --git a/tests/test_graph.py b/tests/test_graph.py index a311fbaa9..f093fab7c 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -470,20 +470,20 @@ def create_testing_nodes_override_B(): def test_create_function_graph_simple(): """Tests that we create a simple function graph.""" expected = create_testing_nodes() - actual = graph.create_function_graph(tests.resources.dummy_functions, config={}) - assert actual == expected + fn_graph = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={}) + assert fn_graph.nodes == expected def test_create_function_graph_with_override(): """Tests that we can override nodes from later modules in function graph.""" override_expected = create_testing_nodes_override_B() - override_actual = graph.create_function_graph( + fn_graph = graph.FunctionGraph.from_modules( tests.resources.dummy_functions, tests.resources.dummy_functions_module_override, config={}, allow_module_overrides=True, ) - assert override_expected == override_actual + assert fn_graph.nodes == override_expected def test_execute():