Skip to content

Commit 20e5038

Browse files
committed
Improve typing in pytensor/utils.py
1 parent 4519df3 commit 20e5038

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

pytensor/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from collections.abc import Iterable, Sequence
1010
from functools import partial
1111
from pathlib import Path
12+
from typing import TypeVar
1213

1314
import numpy as np
1415

@@ -57,6 +58,9 @@
5758
NDARRAY_C_VERSION = np._core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined]
5859

5960

61+
T = TypeVar("T")
62+
63+
6064
def __call_excepthooks(type, value, trace):
6165
"""
6266
This function is meant to replace excepthook and do some
@@ -205,7 +209,7 @@ def hash_from_code(msg: str | bytes) -> str:
205209
return f"m{hashlib.sha256(msg).hexdigest()}"
206210

207211

208-
def uniq(seq: Sequence) -> list:
212+
def uniq(seq: Sequence[T]) -> list[T]:
209213
"""
210214
Do not use set, this must always return the same value at the same index.
211215
If we just exchange other values, but keep the same pattern of duplication,
@@ -217,7 +221,7 @@ def uniq(seq: Sequence) -> list:
217221
return [x for i, x in enumerate(seq) if seq.index(x) == i]
218222

219223

220-
def difference(seq1: Iterable, seq2: Iterable):
224+
def difference(seq1: Iterable[T], seq2: Iterable[T]) -> list[T]:
221225
r"""
222226
Returns all elements in seq1 which are not in seq2: i.e ``seq1\seq2``.
223227
@@ -236,7 +240,7 @@ def difference(seq1: Iterable, seq2: Iterable):
236240
return [x for x in seq1 if x not in seq2]
237241

238242

239-
def to_return_values(values):
243+
def to_return_values(values: Sequence[T]) -> T | Sequence[T]:
240244
if len(values) == 1:
241245
return values[0]
242246
else:

0 commit comments

Comments
 (0)