Skip to content

Commit 0bc4035

Browse files
committed
Various typing improvements
1 parent e0889db commit 0bc4035

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

pytensor/breakpoint.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pytensor.graph.basic import Apply, Variable
55
from pytensor.graph.op import Op
66
from pytensor.tensor.basic import as_tensor_variable
7+
from pytensor.tensor.type import TensorType
78

89

910
class PdbBreakpoint(Op):
@@ -50,23 +51,26 @@ class PdbBreakpoint(Op):
5051
# as the individual error values
5152
breakpointOp = PdbBreakpoint("MSE too high")
5253
condition = pt.gt(mse.sum(), 100)
53-
mse, monitored_input, monitored_target = breakpointOp(condition, mse,
54-
input, target)
54+
mse, monitored_input, monitored_target = breakpointOp(
55+
condition, mse, input, target
56+
)
5557
5658
# Compile the pytensor function
5759
fct = pytensor.function([input, target], mse)
5860
5961
# Use the function
60-
print fct([10, 0], [10, 5]) # Will NOT activate the breakpoint
61-
print fct([0, 0], [10, 5]) # Will activate the breakpoint
62+
print(fct([10, 0], [10, 5])) # Will NOT activate the breakpoint
63+
print(fct([0, 0], [10, 5])) # Will activate the breakpoint
6264
6365
6466
"""
6567

6668
__props__ = ("name",)
6769

68-
def __init__(self, name):
70+
def __init__(self, name: str):
6971
self.name = name
72+
self.view_map = {}
73+
self.inp_types: list[TensorType] = []
7074

7175
def make_node(self, condition, *monitored_vars):
7276
# Ensure that condition is an PyTensor tensor
@@ -83,13 +87,11 @@ def make_node(self, condition, *monitored_vars):
8387
# (view_map and var_types) in that instance and then apply it on the
8488
# inputs.
8589
new_op = PdbBreakpoint(name=self.name)
86-
new_op.view_map = {}
87-
new_op.inp_types = []
88-
for i in range(len(monitored_vars)):
90+
for i, var in enumerate(monitored_vars):
8991
# Every output i is a view of the input i+1 because of the input
9092
# condition.
9193
new_op.view_map[i] = [i + 1]
92-
new_op.inp_types.append(monitored_vars[i].type)
94+
new_op.inp_types.append(var.type)
9395

9496
# Build the Apply node
9597
inputs = [condition, *monitored_vars]

pytensor/configdefaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _good_seem_param(seed):
101101
return True
102102
try:
103103
int(seed)
104-
except Exception:
104+
except ValueError:
105105
return False
106106
return True
107107

0 commit comments

Comments
 (0)