diff --git a/deepgram/clients/common/v1/abstract_sync_websocket.py b/deepgram/clients/common/v1/abstract_sync_websocket.py index 75c6ab71..6dc44375 100644 --- a/deepgram/clients/common/v1/abstract_sync_websocket.py +++ b/deepgram/clients/common/v1/abstract_sync_websocket.py @@ -4,7 +4,7 @@ import json import time import logging -from typing import Dict, Union, Optional, cast, Any, Callable +from typing import Dict, Union, Optional, cast, Any, Callable, Type from datetime import datetime import threading from abc import ABC, abstractmethod @@ -38,6 +38,12 @@ class AbstractSyncWebSocketClient(ABC): # pylint: disable=too-many-instance-att This class provides methods to establish a WebSocket connection generically for use in all WebSocket clients. + + Args: + config (DeepgramClientOptions): all the options for the client + endpoint (str): the endpoint to connect to + thread_cls (Type[threading.Thread]): optional thread class to use for creating threads, + defaults to threading.Thread. Useful for custom thread management like ContextVar support. """ _logger: verboselogs.VerboseLogger @@ -52,12 +58,19 @@ class AbstractSyncWebSocketClient(ABC): # pylint: disable=too-many-instance-att _listen_thread: Union[threading.Thread, None] _delegate: Optional[Speaker] = None + _thread_cls: Type[threading.Thread] + _kwargs: Optional[Dict] = None _addons: Optional[Dict] = None _options: Optional[Dict] = None _headers: Optional[Dict] = None - def __init__(self, config: DeepgramClientOptions, endpoint: str = ""): + def __init__( + self, + config: DeepgramClientOptions, + endpoint: str = "", + thread_cls: Type[threading.Thread] = threading.Thread, + ): if config is None: raise DeepgramError("Config is required") if endpoint == "": @@ -73,6 +86,8 @@ def __init__(self, config: DeepgramClientOptions, endpoint: str = ""): self._listen_thread = None + self._thread_cls = thread_cls + # exit self._exit_event = threading.Event() @@ -152,7 +167,7 @@ def start( self._delegate.set_push_callback(self._process_message) else: self._logger.notice("create _listening thread") - self._listen_thread = threading.Thread(target=self._listening) + self._listen_thread = self._thread_cls(target=self._listening) self._listen_thread.start() # debug the threads diff --git a/deepgram/clients/listen/v1/websocket/client.py b/deepgram/clients/listen/v1/websocket/client.py index a8743007..e6633689 100644 --- a/deepgram/clients/listen/v1/websocket/client.py +++ b/deepgram/clients/listen/v1/websocket/client.py @@ -4,7 +4,7 @@ import json import time import logging -from typing import Dict, Union, Optional, cast, Any, Callable +from typing import Dict, Union, Optional, cast, Any, Callable, Type from datetime import datetime import threading @@ -38,10 +38,12 @@ class ListenWebSocketClient( """ Client for interacting with Deepgram's live transcription services over WebSockets. - This class provides methods to establish a WebSocket connection for live transcription and handle real-time transcription events. + This class provides methods to establish a WebSocket connection for live transcription and handle real-time transcription events. - Args: - config (DeepgramClientOptions): all the options for the client. + Args: + config (DeepgramClientOptions): all the options for the client. + thread_cls (Type[threading.Thread]): optional thread class to use for creating threads, + defaults to threading.Thread. Useful for custom thread management like ContextVar support. """ _logger: verboselogs.VerboseLogger @@ -55,12 +57,18 @@ class ListenWebSocketClient( _flush_thread: Union[threading.Thread, None] _last_datagram: Optional[datetime] = None + _thread_cls: Type[threading.Thread] + _kwargs: Optional[Dict] = None _addons: Optional[Dict] = None _options: Optional[Dict] = None _headers: Optional[Dict] = None - def __init__(self, config: DeepgramClientOptions): + def __init__( + self, + config: DeepgramClientOptions, + thread_cls: Type[threading.Thread] = threading.Thread, + ): if config is None: raise DeepgramError("Config is required") @@ -78,13 +86,19 @@ def __init__(self, config: DeepgramClientOptions): self._last_datagram = None self._lock_flush = threading.Lock() + self._thread_cls = thread_cls + # init handlers self._event_handlers = { event: [] for event in LiveTranscriptionEvents.__members__.values() } # call the parent constructor - super().__init__(self._config, self._endpoint) + super().__init__( + config=self._config, + endpoint=self._endpoint, + thread_cls=self._thread_cls, + ) # pylint: disable=too-many-statements,too-many-branches def start( @@ -154,7 +168,7 @@ def start( # keepalive thread if self._config.is_keep_alive_enabled(): self._logger.notice("keepalive is enabled") - self._keep_alive_thread = threading.Thread(target=self._keep_alive) + self._keep_alive_thread = self._thread_cls(target=self._keep_alive) self._keep_alive_thread.start() else: self._logger.notice("keepalive is disabled") @@ -162,7 +176,7 @@ def start( # flush thread if self._config.is_auto_flush_reply_enabled(): self._logger.notice("autoflush is enabled") - self._flush_thread = threading.Thread(target=self._flush) + self._flush_thread = self._thread_cls(target=self._flush) self._flush_thread.start() else: self._logger.notice("autoflush is disabled")