This is an automated email from the ASF dual-hosted git repository.

rusackas pushed a commit to branch adopt-pr-34525-fix-cache-warmup
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 2ac03e438c540f6e4936be972f3b9cc6a38f3a37
Author: Evan Rusackas <[email protected]>
AuthorDate: Thu Mar 5 09:21:38 2026 -0800

    fix: cache warmup using WebDriver for reliable authentication
    
      Adopted from PR #34525 by @rusackas (originally PR #20387 by @ensky).
      Rebased on master with conflict resolution.
    
      Changes:
      - Use WebDriver (Selenium) to render dashboards for cache warmup
      - Add SUPERSET_CACHE_WARMUP_USER config for specifying the warmup user
      - Support persistent WebDriver instances for efficiency
      - Warm up entire dashboards instead of individual charts
      - Add Celery beat configuration documentation
      - Remove obsolete HTTP-based cache warmup tests
    
      Co-Authored-By: Evan Rusackas <[email protected]>
      Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
---
 docs/admin_docs/configuration/cache.mdx     |  33 +++++
 superset/config.py                          |   3 +
 superset/tasks/cache.py                     | 199 ++++++++++------------------
 superset/utils/screenshots.py               |  12 +-
 superset/utils/webdriver.py                 |  87 +++++++++---
 tests/integration_tests/strategy_tests.py   |  42 ++----
 tests/integration_tests/tasks/test_cache.py |  95 -------------
 tests/integration_tests/tasks/test_utils.py |  77 -----------
 tests/integration_tests/thumbnails_tests.py |  16 +--
 9 files changed, 201 insertions(+), 363 deletions(-)

diff --git a/docs/admin_docs/configuration/cache.mdx b/docs/admin_docs/configuration/cache.mdx
index be1459f09f1..0356b7e0a3d 100644
--- a/docs/admin_docs/configuration/cache.mdx
+++ b/docs/admin_docs/configuration/cache.mdx
@@ -86,6 +86,39 @@ instead requires a cachelib object.
 
 See [Async Queries via Celery](/admin-docs/configuration/async-queries-celery) for details.
 
+## Celery beat
+
+Superset has a Celery task that will periodically warm up the cache based on different strategies.
+To use it, add the following to your `superset_config.py`:
+
+```python
+from celery.schedules import crontab
+from superset.config import CeleryConfig
+
+# User that will be used to authenticate and render dashboards for cache warmup
+SUPERSET_CACHE_WARMUP_USER = "user_with_permission_to_dashboards"
+
+# Extend the default CeleryConfig to add cache warmup schedule
+class CustomCeleryConfig(CeleryConfig):
+    beat_schedule = {
+        **CeleryConfig.beat_schedule,
+        'cache-warmup-hourly': {
+            'task': 'cache-warmup',
+            'schedule': crontab(minute=0, hour='*'),  # hourly
+            'kwargs': {
+                'strategy_name': 'top_n_dashboards',
+                'top_n': 5,
+                'since': '7 days ago',
+            },
+        },
+    }
+
+CELERY_CONFIG = CustomCeleryConfig
+```
+
+This will cache the top 5 most popular dashboards every hour. For other
+strategies, check the `superset/tasks/cache.py` file.
+
 ## Caching Thumbnails
 
 This is an optional feature that can be turned on by activating its [feature flag](/admin-docs/configuration/configuring-superset#feature-flags) on config:
diff --git a/superset/config.py b/superset/config.py
index 30f0f801227..f4ff5d03e6d 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -1056,6 +1056,9 @@ THUMBNAIL_CACHE_CONFIG: CacheConfig = {
 }
 THUMBNAIL_ERROR_CACHE_TTL = int(timedelta(days=1).total_seconds())
 
+# Cache warmup user
+SUPERSET_CACHE_WARMUP_USER = "admin"
+
 # Time before selenium times out after trying to locate an element on the page and wait
 # for that element to load for a screenshot.
 SCREENSHOT_LOCATE_WAIT = int(timedelta(seconds=10).total_seconds())
diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py
index 0f28c070703..83b1dd67b49 100644
--- a/superset/tasks/cache.py
+++ b/superset/tasks/cache.py
@@ -17,65 +17,46 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Optional, TypedDict, Union
-from urllib import request
-from urllib.error import URLError
+from typing import Any, Optional, Union
 
-from celery.beat import SchedulingError
 from celery.utils.log import get_task_logger
 from flask import current_app
+from selenium.common.exceptions import WebDriverException
 from sqlalchemy import and_, func
+from sqlalchemy.orm import selectinload
 
 from superset import db, security_manager
 from superset.extensions import celery_app
 from superset.models.core import Log
 from superset.models.dashboard import Dashboard
-from superset.models.slice import Slice
 from superset.tags.models import Tag, TaggedObject
-from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
-from superset.tasks.utils import fetch_csrf_token, get_executor
-from superset.utils import json
 from superset.utils.date_parser import parse_human_datetime
-from superset.utils.machine_auth import MachineAuthProvider
-from superset.utils.urls import get_url_path, is_secure_url
+from superset.utils.webdriver import WebDriverSelenium
 
 logger = get_task_logger(__name__)
 logger.setLevel(logging.INFO)
 
 
-class CacheWarmupPayload(TypedDict, total=False):
-    chart_id: int
-    dashboard_id: int | None
-
-
-class CacheWarmupTask(TypedDict):
-    payload: CacheWarmupPayload
-    username: str | None
-
-
-def get_task(chart: Slice, dashboard: Optional[Dashboard] = None) -> CacheWarmupTask:
-    """Return task for warming up a given chart/table cache."""
-    executors = current_app.config["CACHE_WARMUP_EXECUTORS"]
-    payload: CacheWarmupPayload = {"chart_id": chart.id}
-    if dashboard:
-        payload["dashboard_id"] = dashboard.id
-
-    username: str | None
-    try:
-        executor = get_executor(executors, chart)
-        username = executor[1]
-    except (ExecutorNotFoundError, InvalidExecutorError):
-        username = None
-
-    return {"payload": payload, "username": username}
+def get_dash_url(dashboard: Dashboard) -> str:
+    """Return external URL for warming up a given dashboard cache."""
+    with current_app.test_request_context():
+        baseurl = (
+            # when running this as an async task, drop the request context with
+            # app.test_request_context()
+            current_app.config.get("WEBDRIVER_BASEURL")
+            or "{SUPERSET_WEBSERVER_PROTOCOL}://"
+            "{SUPERSET_WEBSERVER_ADDRESS}:"
+            "{SUPERSET_WEBSERVER_PORT}".format(**current_app.config)
+        )
+        return f"{baseurl}{dashboard.url}"
 
 
 class Strategy:  # pylint: disable=too-few-public-methods
     """
     A cache warm up strategy.
 
-    Each strategy defines a `get_tasks` method that returns a list of tasks to
-    send to the `/api/v1/chart/warm_up_cache` endpoint.
+    Each strategy defines a `get_urls` method that returns a list of dashboard URLs to
+    warm up using WebDriver.
 
     Strategies can be configured in `superset/config.py`:
 
@@ -96,15 +77,16 @@ class Strategy:  # pylint: disable=too-few-public-methods
     def __init__(self) -> None:
         pass
 
-    def get_tasks(self) -> list[CacheWarmupTask]:
-        raise NotImplementedError("Subclasses must implement get_tasks!")
+    def get_urls(self) -> list[str]:
+        raise NotImplementedError("Subclasses must implement get_urls!")
 
 
 class DummyStrategy(Strategy):  # pylint: disable=too-few-public-methods
     """
-    Warm up all charts.
+    Warm up all published dashboards.
 
-    This is a dummy strategy that will fetch all charts. Can be configured by:
+    This is a dummy strategy that will fetch all published dashboards.
+    Can be configured by:
 
         beat_schedule = {
             'cache-warmup-hourly': {
@@ -118,8 +100,16 @@ class DummyStrategy(Strategy):  # pylint: disable=too-few-public-methods
 
     name = "dummy"
 
-    def get_tasks(self) -> list[CacheWarmupTask]:
-        return [get_task(chart) for chart in db.session.query(Slice).all()]
+    def get_urls(self) -> list[str]:
+        # Use selectinload to avoid N+1 queries when checking dashboard.slices
+        dashboards = (
+            db.session.query(Dashboard)
+            .options(selectinload(Dashboard.slices))
+            .filter(Dashboard.published.is_(True))
+            .all()
+        )
+
+        return [get_dash_url(dashboard) for dashboard in dashboards if dashboard.slices]
 
 
 class TopNDashboardsStrategy(Strategy):  # pylint: disable=too-few-public-methods
@@ -147,7 +137,7 @@ class TopNDashboardsStrategy(Strategy):  # pylint: disable=too-few-public-method
         self.top_n = top_n
         self.since = parse_human_datetime(since) if since else None
 
-    def get_tasks(self) -> list[CacheWarmupTask]:
+    def get_urls(self) -> list[str]:
         records = (
             db.session.query(Log.dashboard_id, func.count(Log.dashboard_id))
             .filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
@@ -161,11 +151,7 @@ class TopNDashboardsStrategy(Strategy):  # pylint: disable=too-few-public-method
             db.session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
         )
 
-        return [
-            get_task(chart, dashboard)
-            for dashboard in dashboards
-            for chart in dashboard.slices
-        ]
+        return [get_dash_url(dashboard) for dashboard in dashboards]
 
 
 class DashboardTagsStrategy(Strategy):  # pylint: disable=too-few-public-methods
@@ -190,8 +176,8 @@ class DashboardTagsStrategy(Strategy):  # pylint: disable=too-few-public-methods
         super().__init__()
         self.tags = tags or []
 
-    def get_tasks(self) -> list[CacheWarmupTask]:
-        tasks = []
+    def get_urls(self) -> list[str]:
+        urls = []
         tags = db.session.query(Tag).filter(Tag.name.in_(self.tags)).all()
         tag_ids = [tag.id for tag in tags]
 
@@ -211,73 +197,14 @@ class DashboardTagsStrategy(Strategy):  # pylint: disable=too-few-public-methods
             Dashboard.id.in_(dash_ids)
         )
         for dashboard in tagged_dashboards:
-            for chart in dashboard.slices:
-                tasks.append(get_task(chart))
-
-        # add charts that are tagged
-        tagged_objects = (
-            db.session.query(TaggedObject)
-            .filter(
-                and_(
-                    TaggedObject.object_type == "chart",
-                    TaggedObject.tag_id.in_(tag_ids),
-                )
-            )
-            .all()
-        )
-        chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
-        tagged_charts = db.session.query(Slice).filter(Slice.id.in_(chart_ids))
-        for chart in tagged_charts:
-            tasks.append(get_task(chart))
+            urls.append(get_dash_url(dashboard))
 
-        return tasks
+        return urls
 
 
 strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy]
 
 
-@celery_app.task(name="fetch_url")
-def fetch_url(data: str, headers: dict[str, str]) -> dict[str, str]:
-    """
-    Celery job to fetch url
-    """
-    result = {}
-    try:
-        url = get_url_path("ChartRestApi.warm_up_cache")
-
-        if is_secure_url(url):
-            logger.info("URL '%s' is secure. Adding Referer header.", url)
-            headers.update({"Referer": url})
-
-        # Fetch CSRF token for API request
-        headers.update(fetch_csrf_token(headers))
-
-        logger.info("Fetching %s with payload %s", url, data)
-        req = request.Request(  # noqa: S310
-            url, data=bytes(data, "utf-8"), headers=headers, method="PUT"
-        )
-        response = request.urlopen(  # pylint: disable=consider-using-with  # noqa: S310
-            req, timeout=600
-        )
-        logger.info(
-            "Fetched %s with payload %s, status code: %s", url, data, response.code
-        )
-        if response.code == 200:
-            result = {"success": data, "response": response.read().decode("utf-8")}
-        else:
-            result = {"error": data, "status_code": response.code}
-            logger.error(
-                "Error fetching %s with payload %s, status code: %s",
-                url,
-                data,
-                response.code,
-            )
-    except URLError as err:
-        logger.exception("Error warming up cache!")
-        result = {"error": data, "exception": str(err)}
-    return result
-
-
 @celery_app.task(name="cache-warmup")
 def cache_warmup(
     strategy_name: str, *args: Any, **kwargs: Any
@@ -285,7 +212,7 @@ def cache_warmup(
     """
     Warm up cache.
 
-    This task periodically hits charts to warm up the cache.
+    This task periodically hits dashboards to warm up the cache.
 
     """
     logger.info("Loading strategy")
@@ -307,25 +234,33 @@ def cache_warmup(
         logger.exception(message)
         return message
 
-    results: dict[str, list[str]] = {"scheduled": [], "errors": []}
-    for task in strategy.get_tasks():
-        username = task["username"]
-        payload = json.dumps(task["payload"])
-        if username:
+    results: dict[str, list[str]] = {"success": [], "errors": []}
+
+    user = security_manager.find_user(
+        username=current_app.config["SUPERSET_CACHE_WARMUP_USER"]
+    )
+    if not user:
+        message = (
+            f"Cache warmup user '{current_app.config['SUPERSET_CACHE_WARMUP_USER']}' "
+            "not found. Please configure SUPERSET_CACHE_WARMUP_USER with a valid "
+            "username."
+        )
+        logger.error(message)
+        return message
+
+    wd = WebDriverSelenium(current_app.config["WEBDRIVER_TYPE"], user=user)
+
+    try:
+        for url in strategy.get_urls():
             try:
-                user = security_manager.get_user_by_username(username)
-                cookies = MachineAuthProvider.get_auth_cookies(user)
-                headers = {
-                    "Cookie": "session=%s" % cookies.get("session", ""),
-                    "Content-Type": "application/json",
-                }
-                logger.info("Scheduling %s", payload)
-                fetch_url.delay(payload, headers)
-                results["scheduled"].append(payload)
-            except SchedulingError:
-                logger.exception("Error scheduling fetch_url for payload: %s", payload)
-                results["errors"].append(payload)
-        else:
-            logger.warning("Executor not found for %s", payload)
+                logger.info("Fetching %s", url)
+                wd.get_screenshot(url, "grid-container")
+                results["success"].append(url)
+            except (WebDriverException, Exception) as ex:  # noqa: BLE001
+                logger.exception("Error warming up cache for %s: %s", url, ex)
+                results["errors"].append(url)
+    finally:
+        # Ensure WebDriver is properly cleaned up
+        wd.destroy()
 
     return results
diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py
index 25c302d35b5..6a5487e45a0 100644
--- a/superset/utils/screenshots.py
+++ b/superset/utils/screenshots.py
@@ -33,8 +33,8 @@ from superset.utils.urls import modify_url_query
 from superset.utils.webdriver import (
     ChartStandaloneMode,
     DashboardStandaloneMode,
-    WebDriver,
     WebDriverPlaywright,
+    WebDriverProxy,
     WebDriverSelenium,
     WindowSize,
 )
@@ -188,7 +188,9 @@ class BaseScreenshot:
         self.url = url
         self.screenshot = None
 
-    def driver(self, window_size: WindowSize | None = None) -> WebDriver:
+    def driver(
+        self, window_size: WindowSize | None = None, user: User | None = None
+    ) -> WebDriverProxy:
         window_size = window_size or self.window_size
         if feature_flag_manager.is_feature_enabled("PLAYWRIGHT_REPORTS_AND_THUMBNAILS"):
             # Try to use Playwright if available (supports WebGL/DeckGL, unlike Cypress)
@@ -204,13 +206,13 @@ class BaseScreenshot:
             )
 
         # Use Selenium as default/fallback
-        return WebDriverSelenium(self.driver_type, window_size)
+        return WebDriverSelenium(self.driver_type, window_size, user)
 
     def get_screenshot(
         self, user: User, window_size: WindowSize | None = None
     ) -> bytes | None:
-        driver = self.driver(window_size)
-        self.screenshot = driver.get_screenshot(self.url, self.element, user)
+        driver = self.driver(window_size, user)
+        self.screenshot = driver.get_screenshot(self.url, self.element)
         return self.screenshot
 
     def get_cache_key(
diff --git a/superset/utils/webdriver.py b/superset/utils/webdriver.py
index 416d3f0e5bf..3663829fdf7 100644
--- a/superset/utils/webdriver.py
+++ b/superset/utils/webdriver.py
@@ -159,7 +159,9 @@ class WebDriverProxy(ABC):
         self._screenshot_load_wait = app.config["SCREENSHOT_LOAD_WAIT"]
 
     @abstractmethod
-    def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None:
+    def get_screenshot(
+        self, url: str, element_name: str, user: User | None = None
+    ) -> bytes | None:
         """
         Run webdriver and return a screenshot
         """
@@ -224,7 +226,7 @@ class WebDriverPlaywright(WebDriverProxy):
             return element.screenshot()
 
     def get_screenshot(  # pylint: disable=too-many-locals, too-many-statements  # noqa: C901
-        self, url: str, element_name: str, user: User
+        self, url: str, element_name: str, user: User | None = None
     ) -> bytes | None:
         if not PLAYWRIGHT_AVAILABLE:
             logger.info(
@@ -252,7 +254,8 @@ class WebDriverPlaywright(WebDriverProxy):
             context.set_default_timeout(
                 app.config["SCREENSHOT_PLAYWRIGHT_DEFAULT_TIMEOUT"]
             )
-            self.auth(user, context)
+            if user:
+                self.auth(user, context)
             page = context.new_page()
             try:
                 page.goto(
@@ -318,7 +321,7 @@ class WebDriverPlaywright(WebDriverProxy):
                 logger.debug(
                     "Taking a PNG screenshot of url %s as user %s",
                     url,
-                    user.username,
+                    user.username if user else "None",
                 )
                 if app.config["SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"]:
                     unexpected_errors = WebDriverPlaywright.find_unexpected_errors(page)
@@ -399,6 +402,29 @@ class WebDriverPlaywright(WebDriverProxy):
 
 
 class WebDriverSelenium(WebDriverProxy):
+    def __init__(
+        self,
+        driver_type: str,
+        window: WindowSize | None = None,
+        user: User | None = None,
+    ):
+        super().__init__(driver_type, window)
+        self._user = user
+        self._driver: WebDriver | None = None
+
+    def __del__(self) -> None:
+        self._destroy()
+
+    @property
+    def driver(self) -> WebDriver:
+        if not self._driver:
+            self._driver = self._create()
+            assert self._driver  # for mypy
+            self._driver.set_window_size(*self._window)
+            if self._user:
+                self._auth(self._user)
+        return self._driver
+
     def _create_firefox_driver(
         self, pixel_density: float
     ) -> tuple[type[WebDriver], type[Service], dict[str, Any]]:
@@ -456,6 +482,22 @@ class WebDriverSelenium(WebDriverProxy):
         return config
 
     def create(self) -> WebDriver:
+        """Create and return the WebDriver instance.
+
+        This is the public interface for creating the driver. It wraps
+        the internal _create method for backward compatibility.
+        """
+        return self._create()
+
+    def destroy(self) -> None:
+        """Destroy the WebDriver instance.
+
+        This is the public interface for cleanup. It wraps the internal
+        _destroy method and should be called when done with the driver.
+        """
+        self._destroy()
+
+    def _create(self) -> WebDriver:
         pixel_density = app.config["WEBDRIVER_WINDOW"].get("pixel_density", 1)
 
         # Get driver class and initial kwargs based on driver type
@@ -516,25 +558,27 @@ class WebDriverSelenium(WebDriverProxy):
         logger.debug("Init selenium driver")
         return driver_class(**kwargs)
 
-    def auth(self, user: User) -> WebDriver:
-        driver = self.create()
+    def _auth(self, user: User) -> WebDriver:
+        driver = self._create()
         return machine_auth_provider_factory.instance.authenticate_webdriver(
             driver, user
         )
 
-    @staticmethod
-    def destroy(driver: WebDriver, tries: int = 2) -> None:
-        """Destroy a driver"""
+    def _destroy(self, tries: int = 2) -> None:
+        """Destroy the persistent driver"""
+        if not self._driver:
+            return
         # This is some very flaky code in selenium. Hence the retries
         # and catch-all exceptions
         try:
-            retry_call(driver.close, max_tries=tries)
+            retry_call(self._driver.close, max_tries=tries)
         except Exception:  # pylint: disable=broad-except  # noqa: S110
             pass
         try:
-            driver.quit()
+            self._driver.quit()
         except Exception:  # pylint: disable=broad-except  # noqa: S110
             pass
+        self._driver = None
 
     @staticmethod
     def find_unexpected_errors(driver: WebDriver) -> list[str]:
@@ -592,9 +636,16 @@ class WebDriverSelenium(WebDriverProxy):
 
         return error_messages
 
-    def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None:  # noqa: C901
-        driver = self.auth(user)
-        driver.set_window_size(*self._window)
+    def get_screenshot(  # noqa: C901
+        self, url: str, element_name: str, user: User | None = None
+    ) -> bytes | None:
+        # If a user is passed explicitly and differs from the stored user,
+        # update and re-authenticate
+        if user and user != self._user:
+            self._user = user
+            if self._driver:
+                self._destroy()
+        driver = self.driver
         driver.get(url)
         img: bytes | None = None
         selenium_headstart = app.config["SCREENSHOT_SELENIUM_HEADSTART"]
@@ -663,7 +714,7 @@ class WebDriverSelenium(WebDriverProxy):
             logger.debug(
                 "Taking a PNG screenshot of url %s as user %s",
                 url,
-                user.username,
+                self._user.username if self._user else "None",
             )
 
             if app.config["SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"]:
@@ -698,5 +749,9 @@ class WebDriverSelenium(WebDriverProxy):
             logger.warning("exception in webdriver", exc_info=ex)
             raise
         finally:
-            self.destroy(driver, app.config["SCREENSHOT_SELENIUM_RETRIES"])
+            # When used as a persistent driver (e.g., cache warmup),
+            # cleanup is handled externally via destroy().
+            # When used for one-off screenshots, the caller or __del__
+            # handles cleanup.
+            pass
         return img
diff --git a/tests/integration_tests/strategy_tests.py b/tests/integration_tests/strategy_tests.py
index 6dc99f501fe..1cbd4b07a32 100644
--- a/tests/integration_tests/strategy_tests.py
+++ b/tests/integration_tests/strategy_tests.py
@@ -82,15 +82,9 @@ class TestCacheWarmUp(SupersetTestCase):
             self.client.get(f"/superset/dashboard/{dash.id}/")
 
         strategy = TopNDashboardsStrategy(1)
-        result = strategy.get_tasks()
-        expected = [
-            {
-                "payload": {"chart_id": chart.id, "dashboard_id": dash.id},
-                "username": "admin",
-            }
-            for chart in dash.slices
-        ]
-        assert len(result) == len(expected)
+        result = sorted(strategy.get_urls())
+        expected = sorted([f"{get_url_host()}{dash.url}"])
+        assert result == expected
 
     def reset_tag(self, tag):
         """Remove associated object from tag, used to reset tests"""
@@ -108,39 +102,27 @@ class TestCacheWarmUp(SupersetTestCase):
         self.reset_tag(tag1)
 
         strategy = DashboardTagsStrategy(["tag1"])
-        assert strategy.get_tasks() == []
+        result = sorted(strategy.get_urls())
+        expected = []
+        assert result == expected
 
         # tag dashboard 'births' with `tag1`
         tag1 = get_tag("tag1", db.session, TagType.custom)
         dash = self.get_dash_by_slug("births")
-        tag1_payloads = [{"chart_id": chart.id} for chart in dash.slices]
+        tag1_urls = [f"{get_url_host()}{dash.url}"]
         tagged_object = TaggedObject(
             tag_id=tag1.id, object_id=dash.id, object_type=ObjectType.dashboard
         )
         db.session.add(tagged_object)
         db.session.commit()
 
-        assert len(strategy.get_tasks()) == len(tag1_payloads)
+        result = sorted(strategy.get_urls())
+        assert result == tag1_urls
 
         strategy = DashboardTagsStrategy(["tag2"])
         tag2 = get_tag("tag2", db.session, TagType.custom)
         self.reset_tag(tag2)
 
-        assert strategy.get_tasks() == []
-
-        # tag first slice
-        dash = self.get_dash_by_slug("unicode-test")
-        chart = dash.slices[0]
-        tag2_payloads = [{"chart_id": chart.id}]
-        object_id = chart.id
-        tagged_object = TaggedObject(
-            tag_id=tag2.id, object_id=object_id, object_type=ObjectType.chart
-        )
-        db.session.add(tagged_object)
-        db.session.commit()
-
-        assert len(strategy.get_tasks()) == len(tag2_payloads)
-
-        strategy = DashboardTagsStrategy(["tag1", "tag2"])
-
-        assert len(strategy.get_tasks()) == len(tag1_payloads + tag2_payloads)
+        result = sorted(strategy.get_urls())
+        expected = []
+        assert result == expected
diff --git a/tests/integration_tests/tasks/test_cache.py b/tests/integration_tests/tasks/test_cache.py
deleted file mode 100644
index 368cb1ebf0a..00000000000
--- a/tests/integration_tests/tasks/test_cache.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from unittest import mock
-
-import pytest
-
-from tests.integration_tests.test_app import app
-
-
-@pytest.mark.parametrize(
-    "base_url, expected_referer",
-    [
-        ("http://base-url", None),
-        ("http://base-url/", None),
-        ("https://base-url", "https://base-url/api/v1/chart/warm_up_cache"),
-        ("https://base-url/", "https://base-url/api/v1/chart/warm_up_cache"),
-    ],
-    ids=[
-        "Without trailing slash (HTTP)",
-        "With trailing slash (HTTP)",
-        "Without trailing slash (HTTPS)",
-        "With trailing slash (HTTPS)",
-    ],
-)
-@mock.patch("superset.tasks.cache.fetch_csrf_token")
-@mock.patch("superset.tasks.cache.request.Request")
-@mock.patch("superset.tasks.cache.request.urlopen")
-@mock.patch("superset.tasks.cache.is_secure_url")
-def test_fetch_url(
-    mock_is_secure_url,
-    mock_urlopen,
-    mock_request_cls,
-    mock_fetch_csrf_token,
-    base_url,
-    expected_referer,
-):
-    from superset.tasks.cache import fetch_url
-
-    mock_request = mock.MagicMock()
-    mock_request_cls.return_value = mock_request
-
-    mock_urlopen.return_value = mock.MagicMock()
-    mock_urlopen.return_value.code = 200
-
-    # Mock the URL validation to return True for HTTPS and False for HTTP
-    mock_is_secure_url.return_value = base_url.startswith("https")
-
-    initial_headers = {"Cookie": "cookie", "key": "value"}
-    csrf_headers = initial_headers | {"X-CSRF-Token": "csrf_token"}
-
-    # Conditionally add the Referer header and assert its presence
-    if expected_referer:
-        csrf_headers = csrf_headers | {"Referer": expected_referer}
-        assert csrf_headers["Referer"] == expected_referer
-
-    mock_fetch_csrf_token.return_value = csrf_headers
-
-    app.config["WEBDRIVER_BASEURL"] = base_url
-    data = "data"
-    data_encoded = b"data"
-
-    result = fetch_url(data, initial_headers)
-
-    expected_url = (
-        f"{base_url}/api/v1/chart/warm_up_cache"
-        if not base_url.endswith("/")
-        else f"{base_url}api/v1/chart/warm_up_cache"
-    )
-
-    mock_fetch_csrf_token.assert_called_once_with(initial_headers)
-
-    mock_request_cls.assert_called_once_with(
-        expected_url,  # Use the dynamic URL based on base_url
-        data=data_encoded,
-        headers=csrf_headers,
-        method="PUT",
-    )
-    # assert the same Request object is used
-    mock_urlopen.assert_called_once_with(mock_request, timeout=mock.ANY)
-
-    assert data == result["success"]
diff --git a/tests/integration_tests/tasks/test_utils.py b/tests/integration_tests/tasks/test_utils.py
deleted file mode 100644
index 29e5f38319c..00000000000
--- a/tests/integration_tests/tasks/test_utils.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from unittest import mock
-
-import pytest
-
-from tests.integration_tests.test_app import app
-
-
-@pytest.mark.parametrize(
-    "base_url",
-    [
-        "http://base-url",
-        "http://base-url/",
-        "https://base-url",
-        "https://base-url/",
-    ],
-    ids=[
-        "Without trailing slash (HTTP)",
-        "With trailing slash (HTTP)",
-        "Without trailing slash (HTTPS)",
-        "With trailing slash (HTTPS)",
-    ],
-)
-@mock.patch("superset.tasks.cache.request.Request")
-@mock.patch("superset.tasks.cache.request.urlopen")
-def test_fetch_csrf_token(mock_urlopen, mock_request_cls, base_url, app_context):
-    from superset.tasks.utils import fetch_csrf_token
-
-    mock_request = mock.MagicMock()
-    mock_request_cls.return_value = mock_request
-
-    mock_response = mock.MagicMock()
-    mock_urlopen.return_value.__enter__.return_value = mock_response
-
-    mock_response.status = 200
-    mock_response.read.return_value = b'{"result": "csrf_token"}'
-    mock_response.headers.get_all.return_value = [
-        "session=new_session_cookie",
-        "async-token=websocket_cookie",
-    ]
-
-    app.config["WEBDRIVER_BASEURL"] = base_url
-    headers = {"Cookie": "original_session_cookie"}
-
-    result_headers = fetch_csrf_token(headers)
-
-    expected_url = (
-        f"{base_url}/api/v1/security/csrf_token/"
-        if not base_url.endswith("/")
-        else f"{base_url}api/v1/security/csrf_token/"
-    )
-
-    mock_request_cls.assert_called_with(
-        expected_url,
-        headers=headers,
-        method="GET",
-    )
-
-    assert result_headers["X-CSRF-Token"] == "csrf_token"
-    assert result_headers["Cookie"] == "session=new_session_cookie"  # Updated assertion
-    # assert the same Request object is used
-    mock_urlopen.assert_called_once_with(mock_request, timeout=mock.ANY)
diff --git a/tests/integration_tests/thumbnails_tests.py b/tests/integration_tests/thumbnails_tests.py
index e35cd242601..81596bc186c 100644
--- a/tests/integration_tests/thumbnails_tests.py
+++ b/tests/integration_tests/thumbnails_tests.py
@@ -147,31 +147,31 @@ class TestWebDriverSelenium(SupersetTestCase):
     def test_screenshot_selenium_headstart(
         self, mock_sleep, mock_webdriver, mock_webdriver_wait
     ):
-        webdriver = WebDriverSelenium("firefox")
         user = security_manager.get_user_by_username(ADMIN_USERNAME)
+        webdriver = WebDriverSelenium("firefox", user=user)
         url = get_url_path("Superset.slice", slice_id=1, standalone="true")
         app.config["SCREENSHOT_SELENIUM_HEADSTART"] = 5
-        webdriver.get_screenshot(url, "chart-container", user=user)
+        webdriver.get_screenshot(url, "chart-container")
         assert mock_sleep.call_args_list[0] == call(5)
 
     @patch("superset.utils.webdriver.WebDriverWait")
     @patch("superset.utils.webdriver.firefox")
     def test_screenshot_selenium_locate_wait(self, mock_webdriver, mock_webdriver_wait):
         app.config["SCREENSHOT_LOCATE_WAIT"] = 15
-        webdriver = WebDriverSelenium("firefox")
         user = security_manager.get_user_by_username(ADMIN_USERNAME)
+        webdriver = WebDriverSelenium("firefox", user=user)
         url = get_url_path("Superset.slice", slice_id=1, standalone="true")
-        webdriver.get_screenshot(url, "chart-container", user=user)
+        webdriver.get_screenshot(url, "chart-container")
         assert mock_webdriver_wait.call_args_list[0] == call(ANY, 15)
 
     @patch("superset.utils.webdriver.WebDriverWait")
     @patch("superset.utils.webdriver.firefox")
     def test_screenshot_selenium_load_wait(self, mock_webdriver, mock_webdriver_wait):
         app.config["SCREENSHOT_LOAD_WAIT"] = 15
-        webdriver = WebDriverSelenium("firefox")
         user = security_manager.get_user_by_username(ADMIN_USERNAME)
+        webdriver = WebDriverSelenium("firefox", user=user)
         url = get_url_path("Superset.slice", slice_id=1, standalone="true")
-        webdriver.get_screenshot(url, "chart-container", user=user)
+        webdriver.get_screenshot(url, "chart-container")
         assert mock_webdriver_wait.call_args_list[2] == call(ANY, 15)
 
     @patch("superset.utils.webdriver.WebDriverWait")
@@ -180,11 +180,11 @@ class TestWebDriverSelenium(SupersetTestCase):
     def test_screenshot_selenium_animation_wait(
         self, mock_sleep, mock_webdriver, mock_webdriver_wait
     ):
-        webdriver = WebDriverSelenium("firefox")
         user = security_manager.get_user_by_username(ADMIN_USERNAME)
+        webdriver = WebDriverSelenium("firefox", user=user)
         url = get_url_path("Superset.slice", slice_id=1, standalone="true")
         app.config["SCREENSHOT_SELENIUM_ANIMATION_WAIT"] = 4
-        webdriver.get_screenshot(url, "chart-container", user=user)
+        webdriver.get_screenshot(url, "chart-container")
         assert mock_sleep.call_args_list[1] == call(4)
 
 


Reply via email to