11"""Define new Ops from existing Ops"""
22
33import warnings
4- from collections .abc import Callable , Sequence
4+ from collections .abc import Callable
55from copy import copy
66from functools import partial
77from itertools import chain
8- from typing import Union , cast
8+ from typing import Union
99
1010from pytensor .compile .function import function
1111from pytensor .compile .function .pfunc import rebuild_collect_shared
@@ -88,12 +88,12 @@ def local_traverse(out):
8888
8989
9090def construct_nominal_fgraph (
91- inputs : Sequence [Variable ], outputs : Sequence [Variable ]
91+ inputs : list [Variable ], outputs : list [Variable ]
9292) -> tuple [
9393 FunctionGraph ,
94- Sequence [ Variable ],
95- dict [Variable , Variable ],
96- dict [ Variable , Variable ],
94+ list [ SharedVariable ],
95+ dict [SharedVariable , Variable ],
96+ list [ Variable ],
9797]:
9898 """Construct an inner-`FunctionGraph` with ordered nominal inputs."""
9999 implicit_shared_inputs = []
@@ -119,7 +119,7 @@ def construct_nominal_fgraph(
119119 )
120120
121121 new = rebuild_collect_shared (
122- cast ( Sequence [ Variable ], outputs ) ,
122+ outputs ,
123123 inputs = inputs + implicit_shared_inputs ,
124124 replace = replacements ,
125125 copy_inputs_over = False ,
@@ -401,7 +401,7 @@ def __init__(
401401 self .output_types = [out .type for out in outputs ]
402402
403403 for override in (lop_overrides , grad_overrides , rop_overrides ):
404- if override == "default" :
404+ if override == "default" : # type: ignore[comparison-overlap]
405405 raise ValueError (
406406 "'default' is no longer a valid value for overrides. Use None instead."
407407 )
@@ -702,7 +702,7 @@ def _build_and_cache_rop_op(self):
702702 # Return a wrapper that combines connected and disconnected output gradients
703703 def wrapper (* inputs : Variable , ** kwargs ) -> list [Variable | None ]:
704704 connected_output_grads = iter (rop_op (* inputs , ** kwargs ))
705- all_output_grads = []
705+ all_output_grads : list [ Variable | None ] = []
706706 for out_grad in output_grads :
707707 if isinstance (out_grad .type , DisconnectedType ):
708708 # R_Op does not have DisconnectedType yet, None should be used instead
0 commit comments