Skip to content

Commit f753e06

Browse files
committed
Improve typing in pytensor/tensor/random/op.py
1 parent bdfd81a commit f753e06

File tree

1 file changed

+9
-9
lines changed
  • pytensor/tensor/random

1 file changed

+9
-9
lines changed

pytensor/tensor/random/op.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,11 @@ class RandomVariable(RNGConsumerOp):
6262

6363
def __init__(
6464
self,
65-
name=None,
65+
name: str | None = None,
6666
ndim_supp=None,
6767
ndims_params=None,
68-
dtype: str | None = None,
69-
inplace=None,
68+
dtype: str | np.dtype | None = None,
69+
inplace: bool | None = None,
7070
signature: str | None = None,
7171
):
7272
"""Create a random variable `Op`.
@@ -115,7 +115,7 @@ def __init__(
115115
if self.signature is not None:
116116
# Assume a single output. Several methods need to be updated to handle multiple outputs.
117117
self.inputs_sig, [self.output_sig] = _parse_gufunc_signature(self.signature)
118-
self.ndims_params = [len(input_sig) for input_sig in self.inputs_sig]
118+
self.ndims_params = tuple([len(input_sig) for input_sig in self.inputs_sig])
119119
self.ndim_supp = len(self.output_sig)
120120
else:
121121
if (
@@ -238,7 +238,7 @@ def _infer_shape(
238238

239239
from pytensor.tensor.extra_ops import broadcast_shape_iter
240240

241-
supp_shape: tuple[Any]
241+
supp_shape: tuple[Any, ...]
242242
if self.ndim_supp == 0:
243243
supp_shape = ()
244244
else:
@@ -406,19 +406,19 @@ def make_node(self, rng, size, *dist_params):
406406
def batch_ndim(self, node: Apply) -> int:
407407
return cast(int, node.default_output().type.ndim - self.ndim_supp)
408408

409-
def rng_param(self, node) -> Variable:
409+
def rng_param(self, node: Apply) -> Variable:
410410
"""Return the node input corresponding to the rng"""
411411
return node.inputs[0]
412412

413-
def size_param(self, node) -> Variable:
413+
def size_param(self, node: Apply) -> Variable:
414414
"""Return the node input corresponding to the size"""
415415
return node.inputs[1]
416416

417-
def dist_params(self, node) -> Sequence[Variable]:
417+
def dist_params(self, node: Apply) -> Sequence[Variable]:
418418
"""Return the node inpust corresponding to dist params"""
419419
return node.inputs[2:]
420420

421-
def perform(self, node, inputs, outputs):
421+
def perform(self, node: Apply, inputs, outputs):
422422
rng, size, *args = inputs
423423

424424
# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.

0 commit comments

Comments
 (0)