From e0c909fab18e851cdb5564295974cea11889f54c Mon Sep 17 00:00:00 2001 From: Hans-Christian Otto Date: Thu, 20 Nov 2025 08:40:30 +0100 Subject: [PATCH] Add support for token refresh --- .../integration_linear/__init__.py | 10 ++ custom_components/integration_linear/api.py | 47 +++++++- custom_components/integration_linear/oauth.py | 109 ++++++++++++++++++ 3 files changed, 165 insertions(+), 1 deletion(-) create mode 100644 custom_components/integration_linear/oauth.py diff --git a/custom_components/integration_linear/__init__.py b/custom_components/integration_linear/__init__.py index 12a419d..9671249 100644 --- a/custom_components/integration_linear/__init__.py +++ b/custom_components/integration_linear/__init__.py @@ -24,6 +24,7 @@ from .const import CONF_API_TOKEN, DOMAIN, LOGGER from .coordinator import BlueprintDataUpdateCoordinator from .data import IntegrationBlueprintData +from .oauth import async_get_valid_token if TYPE_CHECKING: from homeassistant.core import HomeAssistant, ServiceCall @@ -162,6 +163,7 @@ async def async_setup_entry( # Get API token - either from OAuth or from config entry data api_token: str + token_refresh_callback = None if CONF_API_TOKEN in entry.data: # API key authentication api_token = entry.data[CONF_API_TOKEN] @@ -169,6 +171,13 @@ async def async_setup_entry( # OAuth authentication - token is stored in entry.data token = entry.data.get("token", {}) api_token = token.get("access_token", "") + + # Create token refresh callback for OAuth + async def refresh_token() -> str: + """Refresh OAuth token and return new access token.""" + return await async_get_valid_token(hass, entry) + + token_refresh_callback = refresh_token coordinator = BlueprintDataUpdateCoordinator( hass=hass, @@ -180,6 +189,7 @@ async def async_setup_entry( client=IntegrationBlueprintApiClient( api_token=api_token, session=async_get_clientsession(hass), + token_refresh_callback=token_refresh_callback, ), integration=async_get_loaded_integration(hass, entry.domain), coordinator=coordinator, diff --git a/custom_components/integration_linear/api.py b/custom_components/integration_linear/api.py index 2bd5119..53514b0 100644 --- a/custom_components/integration_linear/api.py +++ b/custom_components/integration_linear/api.py @@ -3,7 +3,7 @@ from __future__ import annotations import socket -from typing import Any +from typing import Any, Callable, Awaitable import aiohttp import async_timeout @@ -61,10 +61,13 @@ def __init__( self, api_token: str, session: aiohttp.ClientSession, + token_refresh_callback: Callable[[], Awaitable[str]] | None = None, ) -> None: """Initialize Linear API Client.""" self._api_token = api_token self._session = session + self._token_refresh_callback = token_refresh_callback + self._refresh_in_progress = False async def async_validate_token(self) -> None: """Validate the API token by making a simple query.""" @@ -550,6 +553,7 @@ async def _graphql_query(self, query: str, variables: dict | None = None) -> Any "Authorization": self._api_token, "Content-Type": "application/json", }, + retry_on_auth_error=True, ) async def _api_wrapper( @@ -558,6 +562,7 @@ async def _api_wrapper( url: str, data: dict | None = None, headers: dict | None = None, + retry_on_auth_error: bool = False, ) -> Any: """Get information from the API.""" try: @@ -573,6 +578,46 @@ async def _api_wrapper( result = await response.json() LOGGER.debug("Response: %r", result) + # Check for authentication errors + is_auth_error = False + if response.status in (HTTP_STATUS_UNAUTHORIZED, HTTP_STATUS_FORBIDDEN): + is_auth_error = True + elif response.status >= HTTP_STATUS_BAD_REQUEST and "errors" in result: + error_messages = [] + for err in result["errors"]: + message = err.get("message", "Unknown error") + error_messages.append(message) + extensions = err.get("extensions", {}) + status_code = extensions.get("statusCode") + if status_code in (401, 403) or "unauthorized" in message.lower(): + is_auth_error = True + break + + # Try to refresh token if we have a callback and this is an auth error + if is_auth_error and retry_on_auth_error and self._token_refresh_callback and not self._refresh_in_progress: + LOGGER.info("Authentication error detected, attempting token refresh") + try: + self._refresh_in_progress = True + new_token = await self._token_refresh_callback() + self._api_token = new_token + # Create new headers dict with updated token + retry_headers = dict(headers) if headers else {} + retry_headers["Authorization"] = new_token + # Retry the request once + LOGGER.debug("Retrying request with refreshed token") + response = await self._session.request( + method=method, + url=url, + headers=retry_headers, + json=data, + ) + result = await response.json() + LOGGER.debug("Response after retry: %r", result) + except Exception as refresh_exception: + LOGGER.error("Token refresh failed: %s", refresh_exception) + _raise_authentication_error() + finally: + self._refresh_in_progress = False # Check for HTTP errors if response.status in (HTTP_STATUS_UNAUTHORIZED, HTTP_STATUS_FORBIDDEN): diff --git a/custom_components/integration_linear/oauth.py b/custom_components/integration_linear/oauth.py new file mode 100644 index 0000000..d520b52 --- /dev/null +++ b/custom_components/integration_linear/oauth.py @@ -0,0 +1,109 @@ +"""OAuth2 token refresh helper for Linear Integration.""" + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, Any + +from homeassistant.helpers.config_entry_oauth2_flow import ( + LocalOAuth2ImplementationWithPkce, +) + +from .config_flow import LINEAR_AUTHORIZE_URL, LINEAR_CLIENT_ID, LINEAR_TOKEN_URL +from .const import DOMAIN, LOGGER + +if TYPE_CHECKING: + from homeassistant.config_entries import ConfigEntry + from homeassistant.core import HomeAssistant + + +async def async_get_valid_token( + hass: HomeAssistant, + entry: ConfigEntry, +) -> str: + """ + Get a valid access token, refreshing if necessary. + + Args: + hass: Home Assistant instance + entry: Config entry containing OAuth token + + Returns: + Valid access token + + Raises: + ValueError: If entry doesn't use OAuth or token refresh fails + + """ + # Check if this entry uses OAuth (has token in data, not CONF_API_TOKEN) + if "token" not in entry.data: + msg = "Entry does not use OAuth authentication" + raise ValueError(msg) + + token = entry.data["token"] + access_token = token.get("access_token", "") + + # Check if token is expired or about to expire (within 60 seconds) + expires_at = token.get("expires_at", 0) + if expires_at and time.time() >= (expires_at - 60): + # Token is expired or about to expire, refresh it + LOGGER.debug("Token expired or about to expire, refreshing") + token = await async_refresh_token(hass, entry) + + return token.get("access_token", access_token) + + +async def async_refresh_token( + hass: HomeAssistant, + entry: ConfigEntry, +) -> dict[str, Any]: + """ + Refresh the OAuth token and update the config entry. + + Args: + hass: Home Assistant instance + entry: Config entry containing OAuth token + + Returns: + Updated token data + + Raises: + ValueError: If entry doesn't use OAuth or token refresh fails + + """ + # Check if this entry uses OAuth (has token in data, not CONF_API_TOKEN) + if "token" not in entry.data: + msg = "Entry does not use OAuth authentication" + raise ValueError(msg) + + # Get current token + current_token = entry.data["token"] + + # Create the OAuth2 implementation (same as in config flow) + # This is a local implementation, so we recreate it + implementation = LocalOAuth2ImplementationWithPkce( + hass, + DOMAIN, + LINEAR_CLIENT_ID, + authorize_url=LINEAR_AUTHORIZE_URL, + token_url=LINEAR_TOKEN_URL, + client_secret="", # Empty for PKCE public client + code_verifier_length=128, + ) + + # Refresh the token using the implementation + try: + new_token = await implementation.async_refresh_token(current_token) + except Exception as exception: + LOGGER.error("Failed to refresh OAuth token: %s", exception) + error_msg = f"Token refresh failed: {exception}" + raise ValueError(error_msg) from exception + else: + # Update the config entry with the new token + entry_data = dict(entry.data) + entry_data["token"] = new_token + hass.config_entries.async_update_entry(entry, data=entry_data) + + LOGGER.debug("OAuth token refreshed successfully") + return new_token +