This is an automated email from the ASF dual-hosted git repository. rusackas pushed a commit to branch adopt-pr-34525-fix-cache-warmup in repository https://gitbox.apache.org/repos/asf/superset.git
commit 2ac03e438c540f6e4936be972f3b9cc6a38f3a37 Author: Evan Rusackas <[email protected]> AuthorDate: Thu Mar 5 09:21:38 2026 -0800 fix: cache warmup using WebDriver for reliable authentication Adopted from PR #34525 by @rusackas (originally PR #20387 by @ensky). Rebased on master with conflict resolution. Changes: - Use WebDriver (Selenium) to render dashboards for cache warmup - Add SUPERSET_CACHE_WARMUP_USER config for specifying the warmup user - Support persistent WebDriver instances for efficiency - Warm up entire dashboards instead of individual charts - Add Celery beat configuration documentation - Remove obsolete HTTP-based cache warmup tests Co-Authored-By: Evan Rusackas <[email protected]> Co-Authored-By: Claude Sonnet 4.6 <[email protected]> --- docs/admin_docs/configuration/cache.mdx | 33 +++++ superset/config.py | 3 + superset/tasks/cache.py | 199 ++++++++++------------------ superset/utils/screenshots.py | 12 +- superset/utils/webdriver.py | 87 +++++++++--- tests/integration_tests/strategy_tests.py | 42 ++---- tests/integration_tests/tasks/test_cache.py | 95 ------------- tests/integration_tests/tasks/test_utils.py | 77 ----------- tests/integration_tests/thumbnails_tests.py | 16 +-- 9 files changed, 201 insertions(+), 363 deletions(-) diff --git a/docs/admin_docs/configuration/cache.mdx b/docs/admin_docs/configuration/cache.mdx index be1459f09f1..0356b7e0a3d 100644 --- a/docs/admin_docs/configuration/cache.mdx +++ b/docs/admin_docs/configuration/cache.mdx @@ -86,6 +86,39 @@ instead requires a cachelib object. See [Async Queries via Celery](/admin-docs/configuration/async-queries-celery) for details. +## Celery beat + +Superset has a Celery task that will periodically warm up the cache based on different strategies. 
+To use it, add the following to your `superset_config.py`: + +```python +from celery.schedules import crontab +from superset.config import CeleryConfig + +# User that will be used to authenticate and render dashboards for cache warmup +SUPERSET_CACHE_WARMUP_USER = "user_with_permission_to_dashboards" + +# Extend the default CeleryConfig to add cache warmup schedule +class CustomCeleryConfig(CeleryConfig): + beat_schedule = { + **CeleryConfig.beat_schedule, + 'cache-warmup-hourly': { + 'task': 'cache-warmup', + 'schedule': crontab(minute=0, hour='*'), # hourly + 'kwargs': { + 'strategy_name': 'top_n_dashboards', + 'top_n': 5, + 'since': '7 days ago', + }, + }, + } + +CELERY_CONFIG = CustomCeleryConfig +``` + +This will cache the top 5 most popular dashboards every hour. For other +strategies, check the `superset/tasks/cache.py` file. + ## Caching Thumbnails This is an optional feature that can be turned on by activating its [feature flag](/admin-docs/configuration/configuring-superset#feature-flags) on config: diff --git a/superset/config.py b/superset/config.py index 30f0f801227..f4ff5d03e6d 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1056,6 +1056,9 @@ THUMBNAIL_CACHE_CONFIG: CacheConfig = { } THUMBNAIL_ERROR_CACHE_TTL = int(timedelta(days=1).total_seconds()) +# Cache warmup user +SUPERSET_CACHE_WARMUP_USER = "admin" + # Time before selenium times out after trying to locate an element on the page and wait # for that element to load for a screenshot. 
SCREENSHOT_LOCATE_WAIT = int(timedelta(seconds=10).total_seconds()) diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py index 0f28c070703..83b1dd67b49 100644 --- a/superset/tasks/cache.py +++ b/superset/tasks/cache.py @@ -17,65 +17,46 @@ from __future__ import annotations import logging -from typing import Any, Optional, TypedDict, Union -from urllib import request -from urllib.error import URLError +from typing import Any, Optional, Union -from celery.beat import SchedulingError from celery.utils.log import get_task_logger from flask import current_app +from selenium.common.exceptions import WebDriverException from sqlalchemy import and_, func +from sqlalchemy.orm import selectinload from superset import db, security_manager from superset.extensions import celery_app from superset.models.core import Log from superset.models.dashboard import Dashboard -from superset.models.slice import Slice from superset.tags.models import Tag, TaggedObject -from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError -from superset.tasks.utils import fetch_csrf_token, get_executor -from superset.utils import json from superset.utils.date_parser import parse_human_datetime -from superset.utils.machine_auth import MachineAuthProvider -from superset.utils.urls import get_url_path, is_secure_url +from superset.utils.webdriver import WebDriverSelenium logger = get_task_logger(__name__) logger.setLevel(logging.INFO) -class CacheWarmupPayload(TypedDict, total=False): - chart_id: int - dashboard_id: int | None - - -class CacheWarmupTask(TypedDict): - payload: CacheWarmupPayload - username: str | None - - -def get_task(chart: Slice, dashboard: Optional[Dashboard] = None) -> CacheWarmupTask: - """Return task for warming up a given chart/table cache.""" - executors = current_app.config["CACHE_WARMUP_EXECUTORS"] - payload: CacheWarmupPayload = {"chart_id": chart.id} - if dashboard: - payload["dashboard_id"] = dashboard.id - - username: str | None - try: - 
executor = get_executor(executors, chart) - username = executor[1] - except (ExecutorNotFoundError, InvalidExecutorError): - username = None - - return {"payload": payload, "username": username} +def get_dash_url(dashboard: Dashboard) -> str: + """Return external URL for warming up a given dashboard cache.""" + with current_app.test_request_context(): + baseurl = ( + # when running this as an async task, drop the request context with + # app.test_request_context() + current_app.config.get("WEBDRIVER_BASEURL") + or "{SUPERSET_WEBSERVER_PROTOCOL}://" + "{SUPERSET_WEBSERVER_ADDRESS}:" + "{SUPERSET_WEBSERVER_PORT}".format(**current_app.config) + ) + return f"{baseurl}{dashboard.url}" class Strategy: # pylint: disable=too-few-public-methods """ A cache warm up strategy. - Each strategy defines a `get_tasks` method that returns a list of tasks to - send to the `/api/v1/chart/warm_up_cache` endpoint. + Each strategy defines a `get_urls` method that returns a list of dashboard URLs to + warm up using WebDriver. Strategies can be configured in `superset/config.py`: @@ -96,15 +77,16 @@ class Strategy: # pylint: disable=too-few-public-methods def __init__(self) -> None: pass - def get_tasks(self) -> list[CacheWarmupTask]: - raise NotImplementedError("Subclasses must implement get_tasks!") + def get_urls(self) -> list[str]: + raise NotImplementedError("Subclasses must implement get_urls!") class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods """ - Warm up all charts. + Warm up all published dashboards. - This is a dummy strategy that will fetch all charts. Can be configured by: + This is a dummy strategy that will fetch all published dashboards. 
+ Can be configured by: beat_schedule = { 'cache-warmup-hourly': { @@ -118,8 +100,16 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods name = "dummy" - def get_tasks(self) -> list[CacheWarmupTask]: - return [get_task(chart) for chart in db.session.query(Slice).all()] + def get_urls(self) -> list[str]: + # Use selectinload to avoid N+1 queries when checking dashboard.slices + dashboards = ( + db.session.query(Dashboard) + .options(selectinload(Dashboard.slices)) + .filter(Dashboard.published.is_(True)) + .all() + ) + + return [get_dash_url(dashboard) for dashboard in dashboards if dashboard.slices] class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-methods @@ -147,7 +137,7 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method self.top_n = top_n self.since = parse_human_datetime(since) if since else None - def get_tasks(self) -> list[CacheWarmupTask]: + def get_urls(self) -> list[str]: records = ( db.session.query(Log.dashboard_id, func.count(Log.dashboard_id)) .filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since)) @@ -161,11 +151,7 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method db.session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all() ) - return [ - get_task(chart, dashboard) - for dashboard in dashboards - for chart in dashboard.slices - ] + return [get_dash_url(dashboard) for dashboard in dashboards] class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods @@ -190,8 +176,8 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods super().__init__() self.tags = tags or [] - def get_tasks(self) -> list[CacheWarmupTask]: - tasks = [] + def get_urls(self) -> list[str]: + urls = [] tags = db.session.query(Tag).filter(Tag.name.in_(self.tags)).all() tag_ids = [tag.id for tag in tags] @@ -211,73 +197,14 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods 
Dashboard.id.in_(dash_ids) ) for dashboard in tagged_dashboards: - for chart in dashboard.slices: - tasks.append(get_task(chart)) - - # add charts that are tagged - tagged_objects = ( - db.session.query(TaggedObject) - .filter( - and_( - TaggedObject.object_type == "chart", - TaggedObject.tag_id.in_(tag_ids), - ) - ) - .all() - ) - chart_ids = [tagged_object.object_id for tagged_object in tagged_objects] - tagged_charts = db.session.query(Slice).filter(Slice.id.in_(chart_ids)) - for chart in tagged_charts: - tasks.append(get_task(chart)) + urls.append(get_dash_url(dashboard)) - return tasks + return urls strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy] -@celery_app.task(name="fetch_url") -def fetch_url(data: str, headers: dict[str, str]) -> dict[str, str]: - """ - Celery job to fetch url - """ - result = {} - try: - url = get_url_path("ChartRestApi.warm_up_cache") - - if is_secure_url(url): - logger.info("URL '%s' is secure. Adding Referer header.", url) - headers.update({"Referer": url}) - - # Fetch CSRF token for API request - headers.update(fetch_csrf_token(headers)) - - logger.info("Fetching %s with payload %s", url, data) - req = request.Request( # noqa: S310 - url, data=bytes(data, "utf-8"), headers=headers, method="PUT" - ) - response = request.urlopen( # pylint: disable=consider-using-with # noqa: S310 - req, timeout=600 - ) - logger.info( - "Fetched %s with payload %s, status code: %s", url, data, response.code - ) - if response.code == 200: - result = {"success": data, "response": response.read().decode("utf-8")} - else: - result = {"error": data, "status_code": response.code} - logger.error( - "Error fetching %s with payload %s, status code: %s", - url, - data, - response.code, - ) - except URLError as err: - logger.exception("Error warming up cache!") - result = {"error": data, "exception": str(err)} - return result - - @celery_app.task(name="cache-warmup") def cache_warmup( strategy_name: str, *args: Any, **kwargs: Any @@ 
-285,7 +212,7 @@ def cache_warmup( """ Warm up cache. - This task periodically hits charts to warm up the cache. + This task periodically hits dashboards to warm up the cache. """ logger.info("Loading strategy") @@ -307,25 +234,33 @@ def cache_warmup( logger.exception(message) return message - results: dict[str, list[str]] = {"scheduled": [], "errors": []} - for task in strategy.get_tasks(): - username = task["username"] - payload = json.dumps(task["payload"]) - if username: + results: dict[str, list[str]] = {"success": [], "errors": []} + + user = security_manager.find_user( + username=current_app.config["SUPERSET_CACHE_WARMUP_USER"] + ) + if not user: + message = ( + f"Cache warmup user '{current_app.config['SUPERSET_CACHE_WARMUP_USER']}' " + "not found. Please configure SUPERSET_CACHE_WARMUP_USER with a valid " + "username." + ) + logger.error(message) + return message + + wd = WebDriverSelenium(current_app.config["WEBDRIVER_TYPE"], user=user) + + try: + for url in strategy.get_urls(): try: - user = security_manager.get_user_by_username(username) - cookies = MachineAuthProvider.get_auth_cookies(user) - headers = { - "Cookie": "session=%s" % cookies.get("session", ""), - "Content-Type": "application/json", - } - logger.info("Scheduling %s", payload) - fetch_url.delay(payload, headers) - results["scheduled"].append(payload) - except SchedulingError: - logger.exception("Error scheduling fetch_url for payload: %s", payload) - results["errors"].append(payload) - else: - logger.warning("Executor not found for %s", payload) + logger.info("Fetching %s", url) + wd.get_screenshot(url, "grid-container") + results["success"].append(url) + except (WebDriverException, Exception) as ex: # noqa: BLE001 + logger.exception("Error warming up cache for %s: %s", url, ex) + results["errors"].append(url) + finally: + # Ensure WebDriver is properly cleaned up + wd.destroy() return results diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py index 
25c302d35b5..6a5487e45a0 100644 --- a/superset/utils/screenshots.py +++ b/superset/utils/screenshots.py @@ -33,8 +33,8 @@ from superset.utils.urls import modify_url_query from superset.utils.webdriver import ( ChartStandaloneMode, DashboardStandaloneMode, - WebDriver, WebDriverPlaywright, + WebDriverProxy, WebDriverSelenium, WindowSize, ) @@ -188,7 +188,9 @@ class BaseScreenshot: self.url = url self.screenshot = None - def driver(self, window_size: WindowSize | None = None) -> WebDriver: + def driver( + self, window_size: WindowSize | None = None, user: User | None = None + ) -> WebDriverProxy: window_size = window_size or self.window_size if feature_flag_manager.is_feature_enabled("PLAYWRIGHT_REPORTS_AND_THUMBNAILS"): # Try to use Playwright if available (supports WebGL/DeckGL, unlike Cypress) @@ -204,13 +206,13 @@ class BaseScreenshot: ) # Use Selenium as default/fallback - return WebDriverSelenium(self.driver_type, window_size) + return WebDriverSelenium(self.driver_type, window_size, user) def get_screenshot( self, user: User, window_size: WindowSize | None = None ) -> bytes | None: - driver = self.driver(window_size) - self.screenshot = driver.get_screenshot(self.url, self.element, user) + driver = self.driver(window_size, user) + self.screenshot = driver.get_screenshot(self.url, self.element, user) return self.screenshot def get_cache_key( diff --git a/superset/utils/webdriver.py b/superset/utils/webdriver.py index 416d3f0e5bf..3663829fdf7 100644 --- a/superset/utils/webdriver.py +++ b/superset/utils/webdriver.py @@ -159,7 +159,9 @@ class WebDriverProxy(ABC): self._screenshot_load_wait = app.config["SCREENSHOT_LOAD_WAIT"] @abstractmethod - def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None: + def get_screenshot( + self, url: str, element_name: str, user: User | None = None + ) -> bytes | None: """ Run webdriver and return a screenshot """ @@ -224,7 +226,7 @@ class WebDriverPlaywright(WebDriverProxy): return element.screenshot() def
get_screenshot( # pylint: disable=too-many-locals, too-many-statements # noqa: C901 - self, url: str, element_name: str, user: User + self, url: str, element_name: str, user: User | None = None ) -> bytes | None: if not PLAYWRIGHT_AVAILABLE: logger.info( @@ -252,7 +254,8 @@ class WebDriverPlaywright(WebDriverProxy): context.set_default_timeout( app.config["SCREENSHOT_PLAYWRIGHT_DEFAULT_TIMEOUT"] ) - self.auth(user, context) + if user: + self.auth(user, context) page = context.new_page() try: page.goto( @@ -318,7 +321,7 @@ class WebDriverPlaywright(WebDriverProxy): logger.debug( "Taking a PNG screenshot of url %s as user %s", url, - user.username, + user.username if user else "None", ) if app.config["SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"]: unexpected_errors = WebDriverPlaywright.find_unexpected_errors(page) @@ -399,6 +402,29 @@ class WebDriverPlaywright(WebDriverProxy): class WebDriverSelenium(WebDriverProxy): + def __init__( + self, + driver_type: str, + window: WindowSize | None = None, + user: User | None = None, + ): + super().__init__(driver_type, window) + self._user = user + self._driver: WebDriver | None = None + + def __del__(self) -> None: + self._destroy() + + @property + def driver(self) -> WebDriver: + if not self._driver: + self._driver = self._create() + assert self._driver # for mypy + self._driver.set_window_size(*self._window) + if self._user: + self._auth(self._user) + return self._driver + def _create_firefox_driver( self, pixel_density: float ) -> tuple[type[WebDriver], type[Service], dict[str, Any]]: @@ -456,6 +482,22 @@ class WebDriverSelenium(WebDriverProxy): return config def create(self) -> WebDriver: + """Create and return the WebDriver instance. + + This is the public interface for creating the driver. It wraps + the internal _create method for backward compatibility. + """ + return self._create() + + def destroy(self) -> None: + """Destroy the WebDriver instance. + + This is the public interface for cleanup. 
It wraps the internal + _destroy method and should be called when done with the driver. + """ + self._destroy(app.config["SCREENSHOT_SELENIUM_RETRIES"]) + + def _create(self) -> WebDriver: pixel_density = app.config["WEBDRIVER_WINDOW"].get("pixel_density", 1) # Get driver class and initial kwargs based on driver type @@ -516,25 +558,27 @@ class WebDriverSelenium(WebDriverProxy): logger.debug("Init selenium driver") return driver_class(**kwargs) - def auth(self, user: User) -> WebDriver: - driver = self.create() + def _auth(self, user: User) -> WebDriver: + driver = self._driver or self._create() return machine_auth_provider_factory.instance.authenticate_webdriver( driver, user ) - @staticmethod - def destroy(driver: WebDriver, tries: int = 2) -> None: - """Destroy a driver""" + def _destroy(self, tries: int = 2) -> None: + """Destroy the persistent driver""" + if not getattr(self, "_driver", None): + return # This is some very flaky code in selenium. Hence the retries # and catch-all exceptions try: - retry_call(driver.close, max_tries=tries) + retry_call(self._driver.close, max_tries=tries) except Exception: # pylint: disable=broad-except # noqa: S110 pass try: - driver.quit() + self._driver.quit() except Exception: # pylint: disable=broad-except # noqa: S110 pass + self._driver = None @staticmethod def find_unexpected_errors(driver: WebDriver) -> list[str]: @@ -592,9 +636,16 @@ class WebDriverSelenium(WebDriverProxy): return error_messages - def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None: # noqa: C901 - driver = self.auth(user) - driver.set_window_size(*self._window) + def get_screenshot( # noqa: C901 + self, url: str, element_name: str, user: User | None = None + ) -> bytes | None: + # If a user is passed explicitly and differs from the stored user, + # update and re-authenticate + if user and user != self._user: + self._user = user + if self._driver: + self._destroy() + driver = self.driver driver.get(url) img: bytes | None = None selenium_headstart = app.config["SCREENSHOT_SELENIUM_HEADSTART"] @@
-663,7 +714,7 @@ class WebDriverSelenium(WebDriverProxy): logger.debug( "Taking a PNG screenshot of url %s as user %s", url, - user.username, + self._user.username if self._user else "None", ) if app.config["SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"]: @@ -698,5 +749,9 @@ class WebDriverSelenium(WebDriverProxy): logger.warning("exception in webdriver", exc_info=ex) raise finally: - self.destroy(driver, app.config["SCREENSHOT_SELENIUM_RETRIES"]) + # When used as a persistent driver (e.g., cache warmup), + # cleanup is handled externally via destroy(). + # When used for one-off screenshots, the caller or __del__ + # handles cleanup. + pass return img diff --git a/tests/integration_tests/strategy_tests.py b/tests/integration_tests/strategy_tests.py index 6dc99f501fe..1cbd4b07a32 100644 --- a/tests/integration_tests/strategy_tests.py +++ b/tests/integration_tests/strategy_tests.py @@ -82,15 +82,9 @@ class TestCacheWarmUp(SupersetTestCase): self.client.get(f"/superset/dashboard/{dash.id}/") strategy = TopNDashboardsStrategy(1) - result = strategy.get_tasks() - expected = [ - { - "payload": {"chart_id": chart.id, "dashboard_id": dash.id}, - "username": "admin", - } - for chart in dash.slices - ] - assert len(result) == len(expected) + result = sorted(strategy.get_urls()) + expected = sorted([f"{get_url_host()}{dash.url}"]) + assert result == expected def reset_tag(self, tag): """Remove associated object from tag, used to reset tests""" @@ -108,39 +102,27 @@ class TestCacheWarmUp(SupersetTestCase): self.reset_tag(tag1) strategy = DashboardTagsStrategy(["tag1"]) - assert strategy.get_tasks() == [] + result = sorted(strategy.get_urls()) + expected = [] + assert result == expected # tag dashboard 'births' with `tag1` tag1 = get_tag("tag1", db.session, TagType.custom) dash = self.get_dash_by_slug("births") - tag1_payloads = [{"chart_id": chart.id} for chart in dash.slices] + tag1_urls = [f"{get_url_host()}{dash.url}"] tagged_object = TaggedObject( tag_id=tag1.id, 
object_id=dash.id, object_type=ObjectType.dashboard ) db.session.add(tagged_object) db.session.commit() - assert len(strategy.get_tasks()) == len(tag1_payloads) + result = sorted(strategy.get_urls()) + assert result == tag1_urls strategy = DashboardTagsStrategy(["tag2"]) tag2 = get_tag("tag2", db.session, TagType.custom) self.reset_tag(tag2) - assert strategy.get_tasks() == [] - - # tag first slice - dash = self.get_dash_by_slug("unicode-test") - chart = dash.slices[0] - tag2_payloads = [{"chart_id": chart.id}] - object_id = chart.id - tagged_object = TaggedObject( - tag_id=tag2.id, object_id=object_id, object_type=ObjectType.chart - ) - db.session.add(tagged_object) - db.session.commit() - - assert len(strategy.get_tasks()) == len(tag2_payloads) - - strategy = DashboardTagsStrategy(["tag1", "tag2"]) - - assert len(strategy.get_tasks()) == len(tag1_payloads + tag2_payloads) + result = sorted(strategy.get_urls()) + expected = [] + assert result == expected diff --git a/tests/integration_tests/tasks/test_cache.py b/tests/integration_tests/tasks/test_cache.py deleted file mode 100644 index 368cb1ebf0a..00000000000 --- a/tests/integration_tests/tasks/test_cache.py +++ /dev/null @@ -1,95 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -from unittest import mock - -import pytest - -from tests.integration_tests.test_app import app - - [email protected]( - "base_url, expected_referer", - [ - ("http://base-url", None), - ("http://base-url/", None), - ("https://base-url", "https://base-url/api/v1/chart/warm_up_cache"), - ("https://base-url/", "https://base-url/api/v1/chart/warm_up_cache"), - ], - ids=[ - "Without trailing slash (HTTP)", - "With trailing slash (HTTP)", - "Without trailing slash (HTTPS)", - "With trailing slash (HTTPS)", - ], -) [email protected]("superset.tasks.cache.fetch_csrf_token") [email protected]("superset.tasks.cache.request.Request") [email protected]("superset.tasks.cache.request.urlopen") [email protected]("superset.tasks.cache.is_secure_url") -def test_fetch_url( - mock_is_secure_url, - mock_urlopen, - mock_request_cls, - mock_fetch_csrf_token, - base_url, - expected_referer, -): - from superset.tasks.cache import fetch_url - - mock_request = mock.MagicMock() - mock_request_cls.return_value = mock_request - - mock_urlopen.return_value = mock.MagicMock() - mock_urlopen.return_value.code = 200 - - # Mock the URL validation to return True for HTTPS and False for HTTP - mock_is_secure_url.return_value = base_url.startswith("https") - - initial_headers = {"Cookie": "cookie", "key": "value"} - csrf_headers = initial_headers | {"X-CSRF-Token": "csrf_token"} - - # Conditionally add the Referer header and assert its presence - if expected_referer: - csrf_headers = csrf_headers | {"Referer": expected_referer} - assert csrf_headers["Referer"] == expected_referer - - mock_fetch_csrf_token.return_value = csrf_headers - - app.config["WEBDRIVER_BASEURL"] = base_url - data = "data" - data_encoded = b"data" - - result = fetch_url(data, initial_headers) - - expected_url = ( - f"{base_url}/api/v1/chart/warm_up_cache" - if not base_url.endswith("/") - else 
f"{base_url}api/v1/chart/warm_up_cache" - ) - - mock_fetch_csrf_token.assert_called_once_with(initial_headers) - - mock_request_cls.assert_called_once_with( - expected_url, # Use the dynamic URL based on base_url - data=data_encoded, - headers=csrf_headers, - method="PUT", - ) - # assert the same Request object is used - mock_urlopen.assert_called_once_with(mock_request, timeout=mock.ANY) - - assert data == result["success"] diff --git a/tests/integration_tests/tasks/test_utils.py b/tests/integration_tests/tasks/test_utils.py deleted file mode 100644 index 29e5f38319c..00000000000 --- a/tests/integration_tests/tasks/test_utils.py +++ /dev/null @@ -1,77 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from unittest import mock - -import pytest - -from tests.integration_tests.test_app import app - - [email protected]( - "base_url", - [ - "http://base-url", - "http://base-url/", - "https://base-url", - "https://base-url/", - ], - ids=[ - "Without trailing slash (HTTP)", - "With trailing slash (HTTP)", - "Without trailing slash (HTTPS)", - "With trailing slash (HTTPS)", - ], -) [email protected]("superset.tasks.cache.request.Request") [email protected]("superset.tasks.cache.request.urlopen") -def test_fetch_csrf_token(mock_urlopen, mock_request_cls, base_url, app_context): - from superset.tasks.utils import fetch_csrf_token - - mock_request = mock.MagicMock() - mock_request_cls.return_value = mock_request - - mock_response = mock.MagicMock() - mock_urlopen.return_value.__enter__.return_value = mock_response - - mock_response.status = 200 - mock_response.read.return_value = b'{"result": "csrf_token"}' - mock_response.headers.get_all.return_value = [ - "session=new_session_cookie", - "async-token=websocket_cookie", - ] - - app.config["WEBDRIVER_BASEURL"] = base_url - headers = {"Cookie": "original_session_cookie"} - - result_headers = fetch_csrf_token(headers) - - expected_url = ( - f"{base_url}/api/v1/security/csrf_token/" - if not base_url.endswith("/") - else f"{base_url}api/v1/security/csrf_token/" - ) - - mock_request_cls.assert_called_with( - expected_url, - headers=headers, - method="GET", - ) - - assert result_headers["X-CSRF-Token"] == "csrf_token" - assert result_headers["Cookie"] == "session=new_session_cookie" # Updated assertion - # assert the same Request object is used - mock_urlopen.assert_called_once_with(mock_request, timeout=mock.ANY) diff --git a/tests/integration_tests/thumbnails_tests.py b/tests/integration_tests/thumbnails_tests.py index e35cd242601..81596bc186c 100644 --- a/tests/integration_tests/thumbnails_tests.py +++ b/tests/integration_tests/thumbnails_tests.py @@ -147,31 +147,31 @@ class TestWebDriverSelenium(SupersetTestCase): def 
test_screenshot_selenium_headstart( self, mock_sleep, mock_webdriver, mock_webdriver_wait ): - webdriver = WebDriverSelenium("firefox") user = security_manager.get_user_by_username(ADMIN_USERNAME) + webdriver = WebDriverSelenium("firefox", user=user) url = get_url_path("Superset.slice", slice_id=1, standalone="true") app.config["SCREENSHOT_SELENIUM_HEADSTART"] = 5 - webdriver.get_screenshot(url, "chart-container", user=user) + webdriver.get_screenshot(url, "chart-container") assert mock_sleep.call_args_list[0] == call(5) @patch("superset.utils.webdriver.WebDriverWait") @patch("superset.utils.webdriver.firefox") def test_screenshot_selenium_locate_wait(self, mock_webdriver, mock_webdriver_wait): app.config["SCREENSHOT_LOCATE_WAIT"] = 15 - webdriver = WebDriverSelenium("firefox") user = security_manager.get_user_by_username(ADMIN_USERNAME) + webdriver = WebDriverSelenium("firefox", user=user) url = get_url_path("Superset.slice", slice_id=1, standalone="true") - webdriver.get_screenshot(url, "chart-container", user=user) + webdriver.get_screenshot(url, "chart-container") assert mock_webdriver_wait.call_args_list[0] == call(ANY, 15) @patch("superset.utils.webdriver.WebDriverWait") @patch("superset.utils.webdriver.firefox") def test_screenshot_selenium_load_wait(self, mock_webdriver, mock_webdriver_wait): app.config["SCREENSHOT_LOAD_WAIT"] = 15 - webdriver = WebDriverSelenium("firefox") user = security_manager.get_user_by_username(ADMIN_USERNAME) + webdriver = WebDriverSelenium("firefox", user=user) url = get_url_path("Superset.slice", slice_id=1, standalone="true") - webdriver.get_screenshot(url, "chart-container", user=user) + webdriver.get_screenshot(url, "chart-container") assert mock_webdriver_wait.call_args_list[2] == call(ANY, 15) @patch("superset.utils.webdriver.WebDriverWait") @@ -180,11 +180,11 @@ class TestWebDriverSelenium(SupersetTestCase): def test_screenshot_selenium_animation_wait( self, mock_sleep, mock_webdriver, mock_webdriver_wait ): - webdriver = 
WebDriverSelenium("firefox") user = security_manager.get_user_by_username(ADMIN_USERNAME) + webdriver = WebDriverSelenium("firefox", user=user) url = get_url_path("Superset.slice", slice_id=1, standalone="true") app.config["SCREENSHOT_SELENIUM_ANIMATION_WAIT"] = 4 - webdriver.get_screenshot(url, "chart-container", user=user) + webdriver.get_screenshot(url, "chart-container") assert mock_sleep.call_args_list[1] == call(4)
