44from pytensor .graph .basic import Apply , Variable
55from pytensor .graph .op import Op
66from pytensor .tensor .basic import as_tensor_variable
7+ from pytensor .tensor .type import TensorType
78
89
910class PdbBreakpoint (Op ):
@@ -50,23 +51,26 @@ class PdbBreakpoint(Op):
5051 # as the individual error values
5152 breakpointOp = PdbBreakpoint("MSE too high")
5253 condition = pt.gt(mse.sum(), 100)
53- mse, monitored_input, monitored_target = breakpointOp(condition, mse,
54- input, target)
54+ mse, monitored_input, monitored_target = breakpointOp(
55+ condition, mse, input, target
56+ )
5557
5658 # Compile the pytensor function
5759 fct = pytensor.function([input, target], mse)
5860
5961 # Use the function
60- print fct([10, 0], [10, 5]) # Will NOT activate the breakpoint
61- print fct([0, 0], [10, 5]) # Will activate the breakpoint
62+ print( fct([10, 0], [10, 5])) # Will NOT activate the breakpoint
63+ print( fct([0, 0], [10, 5])) # Will activate the breakpoint
6264
6365
6466 """
6567
6668 __props__ = ("name" ,)
6769
68- def __init__ (self , name ):
70+ def __init__ (self , name : str ):
6971 self .name = name
72+ self .view_map = {}
73+ self .inp_types : list [TensorType ] = []
7074
7175 def make_node (self , condition , * monitored_vars ):
7276 # Ensure that condition is an PyTensor tensor
@@ -83,13 +87,11 @@ def make_node(self, condition, *monitored_vars):
8387 # (view_map and var_types) in that instance and then apply it on the
8488 # inputs.
8589 new_op = PdbBreakpoint (name = self .name )
86- new_op .view_map = {}
87- new_op .inp_types = []
88- for i in range (len (monitored_vars )):
90+ for i , var in enumerate (monitored_vars ):
8991 # Every output i is a view of the input i+1 because of the input
9092 # condition.
9193 new_op .view_map [i ] = [i + 1 ]
92- new_op .inp_types .append (monitored_vars [ i ] .type )
94+ new_op .inp_types .append (var .type )
9395
9496 # Build the Apply node
9597 inputs = [condition , * monitored_vars ]
0 commit comments