diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py
index be489f9c9a..ccf89fbe34 100644
--- a/src/crawlee/crawlers/_basic/_basic_crawler.py
+++ b/src/crawlee/crawlers/_basic/_basic_crawler.py
@@ -1056,9 +1056,10 @@ def _enqueue_links_filter_iterator(
             ) and self._check_url_patterns(target_url, kwargs.get('include'), kwargs.get('exclude')):
                 yield request
 
-                limit = limit - 1 if limit is not None else None
-                if limit and limit <= 0:
-                    break
+                if limit is not None:
+                    limit -= 1
+                    if limit <= 0:
+                        break
 
     def _check_enqueue_strategy(
         self,
diff --git a/tests/unit/crawlers/_beautifulsoup/test_beautifulsoup_crawler.py b/tests/unit/crawlers/_beautifulsoup/test_beautifulsoup_crawler.py
index 04046153c0..1b8b50777b 100644
--- a/tests/unit/crawlers/_beautifulsoup/test_beautifulsoup_crawler.py
+++ b/tests/unit/crawlers/_beautifulsoup/test_beautifulsoup_crawler.py
@@ -428,3 +428,28 @@ async def request_handler(context: BeautifulSoupCrawlingContext) -> None:
 
     assert handler_calls.called
     assert handler_calls.call_count == 1
+
+
+async def test_enqueue_links_with_limit(server_url: URL, http_client: HttpClient) -> None:
+    start_url = str(server_url / 'sub_index')
+    requests = [start_url]
+
+    crawler = BeautifulSoupCrawler(http_client=http_client)
+    visit = mock.Mock()
+
+    @crawler.router.default_handler
+    async def request_handler(context: BeautifulSoupCrawlingContext) -> None:
+        visit(context.request.url)
+        await context.enqueue_links(limit=1)
+
+    await crawler.run(requests)
+
+    first_visited = visit.call_args_list[0][0][0]
+    visited = {call[0][0] for call in visit.call_args_list}
+
+    assert first_visited == start_url
+    # Only one link should be enqueued from sub_index due to the limit
+    assert visited == {
+        start_url,
+        str(server_url / 'page_3'),
+    }
diff --git a/tests/unit/crawlers/_parsel/test_parsel_crawler.py b/tests/unit/crawlers/_parsel/test_parsel_crawler.py
index 65fbd3c303..648b6ee9c0 100644
--- a/tests/unit/crawlers/_parsel/test_parsel_crawler.py
+++ b/tests/unit/crawlers/_parsel/test_parsel_crawler.py
@@ -445,3 +445,28 @@ async def handler(context: ParselCrawlingContext) -> None:
         await context.enqueue_links(rq_id=queue_id, rq_name=queue_name, rq_alias=queue_alias)
 
     await crawler.run([str(server_url / 'start_enqueue')])
+
+
+async def test_enqueue_links_with_limit(server_url: URL, http_client: HttpClient) -> None:
+    start_url = str(server_url / 'sub_index')
+    requests = [start_url]
+
+    crawler = ParselCrawler(http_client=http_client)
+    visit = mock.Mock()
+
+    @crawler.router.default_handler
+    async def request_handler(context: ParselCrawlingContext) -> None:
+        visit(context.request.url)
+        await context.enqueue_links(limit=1)
+
+    await crawler.run(requests)
+
+    first_visited = visit.call_args_list[0][0][0]
+    visited = {call[0][0] for call in visit.call_args_list}
+
+    assert first_visited == start_url
+    # Only one link should be enqueued from sub_index due to the limit
+    assert visited == {
+        start_url,
+        str(server_url / 'page_3'),
+    }
diff --git a/tests/unit/crawlers/_playwright/test_playwright_crawler.py b/tests/unit/crawlers/_playwright/test_playwright_crawler.py
index 2702010ba8..a2e823c195 100644
--- a/tests/unit/crawlers/_playwright/test_playwright_crawler.py
+++ b/tests/unit/crawlers/_playwright/test_playwright_crawler.py
@@ -1051,3 +1051,28 @@ async def failed_request_handler(context: BasicCrawlingContext, _error: Exception) -> None:
     }
 
     await queue.drop()
+
+
+async def test_enqueue_links_with_limit(server_url: URL) -> None:
+    start_url = str(server_url / 'sub_index')
+    requests = [start_url]
+
+    crawler = PlaywrightCrawler()
+    visit = mock.Mock()
+
+    @crawler.router.default_handler
+    async def request_handler(context: PlaywrightCrawlingContext) -> None:
+        visit(context.request.url)
+        await context.enqueue_links(limit=1)
+
+    await crawler.run(requests)
+
+    first_visited = visit.call_args_list[0][0][0]
+    visited = {call[0][0] for call in visit.call_args_list}
+
+    assert first_visited == start_url
+    # Only one link should be enqueued from sub_index due to the limit
+    assert visited == {
+        start_url,
+        str(server_url / 'page_3'),
+    }
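
For context on the crawler change: the guard the patch removes, `if limit and limit <= 0`, can never fire, because once `limit` is decremented to `0` it is falsy, so the `break` is skipped and `enqueue_links(limit=...)` effectively ignores the limit. Below is a minimal standalone sketch of the corrected counting pattern; the `take_with_limit` helper is illustrative only and not part of the crawlee codebase.

```python
from collections.abc import Iterator


def take_with_limit(items: Iterator[str], limit: int | None) -> Iterator[str]:
    """Yield items, stopping early once `limit` of them have been yielded."""
    for item in items:
        yield item

        if limit is not None:
            limit -= 1
            # With the pre-patch guard `if limit and limit <= 0`, limit == 0 is
            # falsy, so this break would never be reached.
            if limit <= 0:
                break


assert list(take_with_limit(iter('abc'), limit=1)) == ['a']
assert list(take_with_limit(iter('abc'), limit=None)) == ['a', 'b', 'c']
```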