
Commit 4519df3

Improve typing in pytensor/compile/builders.py

1 parent f753e06

1 file changed: +9 −9 lines changed

pytensor/compile/builders.py
Lines changed: 9 additions & 9 deletions
@@ -1,11 +1,11 @@
 """Define new Ops from existing Ops"""

 import warnings
-from collections.abc import Callable, Sequence
+from collections.abc import Callable
 from copy import copy
 from functools import partial
 from itertools import chain
-from typing import Union, cast
+from typing import Union

 from pytensor.compile.function import function
 from pytensor.compile.function.pfunc import rebuild_collect_shared
@@ -88,12 +88,12 @@ def local_traverse(out):


 def construct_nominal_fgraph(
-    inputs: Sequence[Variable], outputs: Sequence[Variable]
+    inputs: list[Variable], outputs: list[Variable]
 ) -> tuple[
     FunctionGraph,
-    Sequence[Variable],
-    dict[Variable, Variable],
-    dict[Variable, Variable],
+    list[SharedVariable],
+    dict[SharedVariable, Variable],
+    list[Variable],
 ]:
     """Construct an inner-`FunctionGraph` with ordered nominal inputs."""
     implicit_shared_inputs = []
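For orientation, a minimal sketch of what the narrowed return annotation buys a caller under mypy. The classes here are stand-ins, not pytensor's real Variable/SharedVariable: the point is that a broad Sequence[Variable] hides both mutability and the more precise SharedVariable element type.

from collections.abc import Sequence


class Variable:
    """Stand-in for pytensor's Variable."""


class SharedVariable(Variable):
    """Stand-in for pytensor's SharedVariable."""


def before() -> Sequence[Variable]:
    # Old-style broad annotation: callers only see an immutable
    # sequence of plain Variables.
    return [SharedVariable()]


def after() -> list[SharedVariable]:
    # Narrowed annotation: callers see a mutable list with the
    # precise element type.
    return [SharedVariable()]


xs = after()
xs.append(SharedVariable())  # OK: list is mutable and the element type is known
# before().append(...)       # mypy rejects this: Sequence has no append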
@@ -119,7 +119,7 @@ def construct_nominal_fgraph(
     )

     new = rebuild_collect_shared(
-        cast(Sequence[Variable], outputs),
+        outputs,
         inputs=inputs + implicit_shared_inputs,
         replace=replacements,
         copy_inputs_over=False,
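A generic sketch (hypothetical consume/caller functions, not pytensor code) of why the cast becomes redundant: once the parameter is annotated list[Variable], it already satisfies any Sequence-typed consumer, so the cast was pure noise.

from collections.abc import Sequence
from typing import cast


def consume(xs: Sequence[int]) -> None:
    print(sum(xs))


def caller(xs: list[int]) -> None:
    consume(cast(Sequence[int], xs))  # old pattern: the cast adds nothing
    consume(xs)                       # list[int] already satisfies Sequence[int]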
@@ -401,7 +401,7 @@ def __init__(
         self.output_types = [out.type for out in outputs]

         for override in (lop_overrides, grad_overrides, rop_overrides):
-            if override == "default":
+            if override == "default":  # type: ignore[comparison-overlap]
                 raise ValueError(
                     "'default' is no longer a valid value for overrides. Use None instead."
                 )
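The ignore comment silences mypy's strict-equality diagnostic. A hedged sketch, assuming the overrides are annotated roughly as Callable | None (an assumption; the exact annotation in builders.py may differ): under --strict-equality, comparing that type against a string literal is reported as a non-overlapping check, even though the comparison is deliberately kept as a runtime guard.

from collections.abc import Callable


def validate(override: Callable | None) -> None:
    # Under mypy --strict-equality this comparison is reported as
    # [comparison-overlap]: a Callable | None value can never equal a str.
    # The check is kept as a runtime guard for untyped callers, so the
    # specific error code is suppressed rather than the check removed.
    if override == "default":  # type: ignore[comparison-overlap]
        raise ValueError(
            "'default' is no longer a valid value for overrides. Use None instead."
        )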
@@ -702,7 +702,7 @@ def _build_and_cache_rop_op(self):
         # Return a wrapper that combines connected and disconnected output gradients
         def wrapper(*inputs: Variable, **kwargs) -> list[Variable | None]:
             connected_output_grads = iter(rop_op(*inputs, **kwargs))
-            all_output_grads = []
+            all_output_grads: list[Variable | None] = []
             for out_grad in output_grads:
                 if isinstance(out_grad.type, DisconnectedType):
                     # R_Op does not have DisconnectedType yet, None should be used instead
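Why the annotation on the empty list matters, as a standalone sketch (hypothetical combine function mirroring the wrapper's shape, with a stand-in Variable class): without it, mypy infers the element type from the first append and then rejects appending None.

class Variable:
    """Stand-in for pytensor's Variable."""


def combine(grads: list[Variable], disconnected: list[bool]) -> list[Variable | None]:
    connected = iter(grads)
    # Annotated up front: an unannotated `[]` would be inferred as
    # list[Variable] after the first append and then reject append(None).
    all_output_grads: list[Variable | None] = []
    for is_disconnected in disconnected:
        all_output_grads.append(None if is_disconnected else next(connected))
    return all_output_grads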

0 commit comments