Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hamilton/function_modifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def find_injectable_params(nodes: Collection[node.Node]) -> Dict[str, Type[Type]

def transform_dag(
self, nodes: Collection[node.Node], config: Dict[str, Any], fn: Callable
) -> Collection[node.Node]:
) -> List[node.Node]:
"""Transforms the subDAG by getting the injectable parameters (anything not
produced by nodes inside it), then calling the inject_nodes function on it.

Expand Down
23 changes: 20 additions & 3 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def update_dependencies(


def create_function_graph(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type signature of this should be updated.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also 🤦 (at myself I think) that the name of this function and what it returns - a dict, not a function graph...

*modules: ModuleType,
*functions: List[Tuple[str, Callable]],
config: Dict[str, Any],
adapter: lifecycle_base.LifecycleAdapterSet = None,
fg: Optional["FunctionGraph"] = None,
Expand All @@ -164,7 +164,6 @@ def create_function_graph(
nodes = {} # name -> Node
else:
nodes = fg.nodes
functions = sum([find_functions(module) for module in modules], [])

# create non-input nodes -- easier to just create this in one loop
for _func_name, f in functions:
Expand Down Expand Up @@ -733,8 +732,26 @@ def from_modules(
:return: a function graph.
"""

functions = sum([find_functions(module) for module in modules], [])
return FunctionGraph.from_functions(
*functions,
config=config,
adapter=adapter,
allow_module_overrides=allow_module_overrides,
)

@staticmethod
def from_functions(
*functions,
config: Dict[str, Any],
adapter: lifecycle_base.LifecycleAdapterSet = None,
allow_module_overrides: bool = False,
) -> "FunctionGraph":
nodes = create_function_graph(
*modules, config=config, adapter=adapter, allow_module_overrides=allow_module_overrides
*functions,
config=config,
adapter=adapter,
allow_module_overrides=allow_module_overrides,
)
return FunctionGraph(nodes, config, adapter)

Expand Down
25 changes: 16 additions & 9 deletions hamilton/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,25 @@ def is_submodule(child: ModuleType, parent: ModuleType):
return parent.__name__ in child.__name__


def is_hamilton_function(fn: Callable) -> bool:
"""A `Hamilton function` defines a node.
To be valid, it must not start with the underscore `_` prefix.
"""
return inspect.isfunction(fn) and not fn.__name__.startswith("_")


# NOTE. This should return a list of callables instead of tuples. Internally,
# the function names are never used (except in the CLI) and that information
# is readily available through `fn.__name__`.
# Care is required because users may have been advised to use this code path.
def find_functions(function_module: ModuleType) -> List[Tuple[str, Callable]]:
"""Function to determine the set of functions we want to build a graph from.

This iterates through the function module and grabs all function definitions.
:return: list of tuples of (func_name, function).
"""

def valid_fn(fn):
return (
inspect.isfunction(fn)
and not fn.__name__.startswith("_")
and is_submodule(inspect.getmodule(fn), function_module)
)

return [f for f in inspect.getmembers(function_module, predicate=valid_fn)]
return [
(name, fn)
for name, fn in inspect.getmembers(function_module)
if is_hamilton_function(fn) and is_submodule(inspect.getmodule(fn), function_module)
]
12 changes: 8 additions & 4 deletions tests/execution/test_node_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,10 @@ def bar(foo: int) -> int:
# This is hacking around function graph which is messy as it is built of larger components
# (modules), and should instead be broken into smaller pieces (functions/nodes), and have utilities
# to create it from those.
fn_graph = graph.create_function_graph(ad_hoc_utils.create_temporary_module(bar), config={})
node_ = fn_graph["bar"]
fn_graph = graph.FunctionGraph.from_modules(
ad_hoc_utils.create_temporary_module(bar), config={}
)
node_ = fn_graph.nodes["bar"]
task = grouping.TaskSpec(
base_id="bar",
nodes=[node_],
Expand All @@ -120,8 +122,10 @@ def bar(foo: int, baz: int = 1) -> int:
# This is hacking around function graph which is messy as it is built of larger components
# (modules), and should instead be broken into smaller pieces (functions/nodes), and have utilities
# to create it from those.
fn_graph = graph.create_function_graph(ad_hoc_utils.create_temporary_module(bar), config={})
node_ = fn_graph["bar"]
fn_graph = graph.FunctionGraph.from_modules(
ad_hoc_utils.create_temporary_module(bar), config={}
)
node_ = fn_graph.nodes["bar"]
task = grouping.TaskSpec(
base_id="bar",
nodes=[node_],
Expand Down
55 changes: 23 additions & 32 deletions tests/function_modifiers/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,12 +570,11 @@ def test_loader_default_factory_field():
def foo(param: int) -> int:
return param

fg = graph.create_function_graph(
ad_hoc_utils.create_temporary_module(foo),
config={},
fn_graph = graph.FunctionGraph.from_modules(
ad_hoc_utils.create_temporary_module(foo), config={}
)
assert len(fg) == 3
assert "foo" in fg
assert len(fn_graph.nodes) == 3
assert "foo" in fn_graph.nodes


@dataclasses.dataclass
Expand All @@ -602,12 +601,11 @@ def test_saver_default_factory_field():
def foo(param: int) -> int:
return param

fg = graph.create_function_graph(
ad_hoc_utils.create_temporary_module(foo),
config={},
fn_graph = graph.FunctionGraph.from_modules(
ad_hoc_utils.create_temporary_module(foo), config={}
)
assert len(fg) == 3
assert "foo" in fg
assert len(fn_graph.nodes) == 3
assert "foo" in fn_graph.nodes


@dataclasses.dataclass
Expand All @@ -631,12 +629,11 @@ def test_adapters_optional_params():
def foo(param: int) -> int:
return param

fg = graph.create_function_graph(
ad_hoc_utils.create_temporary_module(foo),
config={},
fn_graph = graph.FunctionGraph.from_modules(
ad_hoc_utils.create_temporary_module(foo), config={}
)
assert len(fg) == 3
assert "foo" in fg
assert len(fn_graph.nodes) == 3
assert "foo" in fn_graph.nodes


def test_save_to_with_input_from_other_fn():
Expand All @@ -648,12 +645,11 @@ def output_path() -> str:
def fn() -> dict:
return {"a": 1}

fg = graph.create_function_graph(
ad_hoc_utils.create_temporary_module(output_path, fn),
config={},
fn_graph = graph.FunctionGraph.from_modules(
ad_hoc_utils.create_temporary_module(output_path, fn), config={}
)

assert len(fg) == 3
assert len(fn_graph.nodes) == 3


def test_load_from_with_input_from_other_fn():
Expand All @@ -665,11 +661,10 @@ def input_path() -> str:
def fn(data: dict) -> dict:
return data

fg = graph.create_function_graph(
ad_hoc_utils.create_temporary_module(input_path, fn),
config={},
fn_graph = graph.FunctionGraph.from_modules(
ad_hoc_utils.create_temporary_module(input_path, fn), config={}
)
assert len(fg) == 4
assert len(fn_graph.nodes) == 4


def test_load_from_with_multiple_inputs():
Expand All @@ -686,12 +681,9 @@ def test_load_from_with_multiple_inputs():
def fn(data1: dict, data2: dict) -> dict:
return {**data1, **data2}

fg = graph.create_function_graph(
ad_hoc_utils.create_temporary_module(fn),
config={},
)
fn_graph = graph.FunctionGraph.from_modules(ad_hoc_utils.create_temporary_module(fn), config={})
# One filter, one loader for each and the transform function
assert len(fg) == 5
assert len(fn_graph.nodes) == 5


import sys
Expand Down Expand Up @@ -819,12 +811,11 @@ def test_dataloader_future_annotations():
from tests.resources import nodes_with_future_annotation

fn_to_collect = nodes_with_future_annotation.sample_dataloader
fg = graph.create_function_graph(
ad_hoc_utils.create_temporary_module(fn_to_collect),
config={},
fn_graph = graph.FunctionGraph.from_modules(
ad_hoc_utils.create_temporary_module(fn_to_collect), config={}
)
# the data loaded is a list
assert custom_subclass_check(fg["sample_dataloader"].type, list)
assert custom_subclass_check(fn_graph.nodes["sample_dataloader"].type, list)


def test_datasaver():
Expand Down
8 changes: 4 additions & 4 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,20 +470,20 @@ def create_testing_nodes_override_B():
def test_create_function_graph_simple():
"""Tests that we create a simple function graph."""
expected = create_testing_nodes()
actual = graph.create_function_graph(tests.resources.dummy_functions, config={})
assert actual == expected
fn_graph = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={})
assert fn_graph.nodes == expected


def test_create_function_graph_with_override():
"""Tests that we can override nodes from later modules in function graph."""
override_expected = create_testing_nodes_override_B()
override_actual = graph.create_function_graph(
fn_graph = graph.FunctionGraph.from_modules(
tests.resources.dummy_functions,
tests.resources.dummy_functions_module_override,
config={},
allow_module_overrides=True,
)
assert override_expected == override_actual
assert fn_graph.nodes == override_expected


def test_execute():
Expand Down