Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dependencies = [
"paho-mqtt>=1.6.1,<3.0.0",
"construct>=2.10.57,<3",
"vacuum-map-parser-roborock",
"pyrate-limiter>=3.7.0,<4",
"pyrate-limiter>=3.7.0,<5",
"aiomqtt>=2.5.0,<3",
"click-shell~=2.1",
]
Expand Down
58 changes: 26 additions & 32 deletions roborock/web_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import aiohttp
from aiohttp import ContentTypeError, FormData
from pyrate_limiter import BucketFullException, Duration, Limiter, Rate
from pyrate_limiter import Duration, Limiter, Rate

from roborock import HomeDataSchedule
from roborock.data import HomeData, HomeDataRoom, HomeDataScene, ProductResponse, RRiot, UserData
Expand Down Expand Up @@ -62,7 +62,7 @@ class RoborockApiClient:
Rate(40, Duration.DAY),
]

_login_limiter = Limiter(_LOGIN_RATES, max_delay=1000)
_login_limiter = Limiter(_LOGIN_RATES)
_home_data_limiter = Limiter(_HOME_DATA_RATES)

def __init__(
Expand Down Expand Up @@ -204,11 +204,10 @@ async def add_device(self, user_data: UserData, s: str, t: str) -> dict:
return add_device_response["result"]

async def request_code(self) -> None:
try:
await self._login_limiter.try_acquire_async("login")
except BucketFullException as ex:
_LOGGER.info(ex.meta_info)
raise RoborockRateLimit("Reached maximum requests for login. Please try again later.") from ex
success = await self._login_limiter.try_acquire_async("login", blocking=False)
if not success:
_LOGGER.info("Rate limit reached for login")
raise RoborockRateLimit("Reached maximum requests for login. Please try again later.")
base_url = await self.base_url
header_clientid = self._get_header_client_id()
code_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
Expand Down Expand Up @@ -238,11 +237,10 @@ async def request_code_v4(self) -> None:
if await self.country_code is None or await self.country is None:
_LOGGER.info("No country code or country found, trying old version of request code.")
return await self.request_code()
try:
await self._login_limiter.try_acquire_async("login")
except BucketFullException as ex:
_LOGGER.info(ex.meta_info)
raise RoborockRateLimit("Reached maximum requests for login. Please try again later.") from ex
success = await self._login_limiter.try_acquire_async("login", blocking=False)
if not success:
_LOGGER.info("Rate limit reached for login")
raise RoborockRateLimit("Reached maximum requests for login. Please try again later.")
base_url = await self.base_url
header_clientid = self._get_header_client_id()
code_request = PreparedRequest(
Expand Down Expand Up @@ -370,11 +368,10 @@ async def code_login_v4(
return UserData.from_dict(user_data)

async def pass_login(self, password: str) -> UserData:
try:
await self._login_limiter.try_acquire_async("login")
except BucketFullException as ex:
_LOGGER.info(ex.meta_info)
raise RoborockRateLimit("Reached maximum requests for login. Please try again later.") from ex
success = await self._login_limiter.try_acquire_async("login", blocking=False)
if not success:
_LOGGER.info("Rate limit reached for login")
raise RoborockRateLimit("Reached maximum requests for login. Please try again later.")
base_url = await self.base_url
header_clientid = self._get_header_client_id()

Expand Down Expand Up @@ -468,11 +465,10 @@ async def _get_home_id(self, user_data: UserData):
return home_id_response["data"]["rrHomeId"]

async def get_home_data(self, user_data: UserData) -> HomeData:
try:
self._home_data_limiter.try_acquire("home_data")
except BucketFullException as ex:
_LOGGER.info(ex.meta_info)
raise RoborockRateLimit("Reached maximum requests for home data. Please try again later.") from ex
success = self._home_data_limiter.try_acquire("home_data", blocking=False)
if not success:
_LOGGER.info("Rate limit reached for home data")
raise RoborockRateLimit("Reached maximum requests for home data. Please try again later.")
rriot = user_data.rriot
if rriot is None:
raise RoborockException("rriot is none")
Expand All @@ -497,11 +493,10 @@ async def get_home_data(self, user_data: UserData) -> HomeData:

async def get_home_data_v2(self, user_data: UserData) -> HomeData:
"""This is the same as get_home_data, but uses a different endpoint and includes non-robotic vacuums."""
try:
self._home_data_limiter.try_acquire("home_data")
except BucketFullException as ex:
_LOGGER.info(ex.meta_info)
raise RoborockRateLimit("Reached maximum requests for home data. Please try again later.") from ex
success = self._home_data_limiter.try_acquire("home_data", blocking=False)
if not success:
_LOGGER.info("Rate limit reached for home data")
raise RoborockRateLimit("Reached maximum requests for home data. Please try again later.")
rriot = user_data.rriot
if rriot is None:
raise RoborockException("rriot is none")
Expand All @@ -526,11 +521,10 @@ async def get_home_data_v2(self, user_data: UserData) -> HomeData:

async def get_home_data_v3(self, user_data: UserData) -> HomeData:
"""This is the same as get_home_data, but uses a different endpoint and includes non-robotic vacuums."""
try:
self._home_data_limiter.try_acquire("home_data")
except BucketFullException as ex:
_LOGGER.info(ex.meta_info)
raise RoborockRateLimit("Reached maximum requests for home data. Please try again later.") from ex
success = self._home_data_limiter.try_acquire("home_data", blocking=False)
if not success:
_LOGGER.info("Rate limit reached for home data")
raise RoborockRateLimit("Reached maximum requests for home data. Please try again later.")
rriot = user_data.rriot
home_id = await self._get_home_id(user_data)
if rriot.r.a is None:
Expand Down
5 changes: 3 additions & 2 deletions tests/fixtures/web_api_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
def skip_rate_limit() -> Generator[None, None, None]:
"""Don't rate limit tests as they aren't actually hitting the api."""
with (
patch("roborock.web_api.RoborockApiClient._login_limiter.try_acquire"),
patch("roborock.web_api.RoborockApiClient._home_data_limiter.try_acquire"),
patch("roborock.web_api.RoborockApiClient._login_limiter.try_acquire", return_value=True),
patch("roborock.web_api.RoborockApiClient._login_limiter.try_acquire_async", return_value=True),
patch("roborock.web_api.RoborockApiClient._home_data_limiter.try_acquire", return_value=True),
):
yield

Expand Down
Loading