This is an automated email from the ASF dual-hosted git repository.

arivero pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new c54b21ef988 fix(mcp): add eager loading to get_info tools to prevent N+1 query timeouts (#38129)
c54b21ef988 is described below

commit c54b21ef988a606012f24d65dde7ae720f64ba1d
Author: Amin Ghadersohi <[email protected]>
AuthorDate: Wed Feb 25 11:28:58 2026 -0500

    fix(mcp): add eager loading to get_info tools to prevent N+1 query timeouts (#38129)
    
    Co-authored-by: Claude Opus 4.6 <[email protected]>
---
 superset/daos/base.py                              | 11 +++++++++-
 superset/daos/database.py                          |  6 +++++-
 superset/mcp_service/chart/tool/get_chart_info.py  |  9 ++++++++
 .../dashboard/tool/get_dashboard_info.py           | 15 ++++++++++++++
 .../mcp_service/dataset/tool/get_dataset_info.py   | 12 +++++++++++
 superset/mcp_service/mcp_core.py                   | 24 ++++++++++++++--------
 .../mcp_service/system/tool/test_mcp_core.py       |  4 ++--
 7 files changed, 68 insertions(+), 13 deletions(-)

diff --git a/superset/daos/base.py b/superset/daos/base.py
index 4897f796cf8..136184304b5 100644
--- a/superset/daos/base.py
+++ b/superset/daos/base.py
@@ -248,6 +248,7 @@ class BaseDAO(CoreBaseDAO[T], Generic[T]):
         column_name: str,
         value: str | int,
         skip_base_filter: bool = False,
+        query_options: list[Any] | None = None,
     ) -> T | None:
         """
         Private method to find a model by any column value.
@@ -256,6 +257,8 @@ class BaseDAO(CoreBaseDAO[T], Generic[T]):
             column_name: Name of the column to search by
             value: Value to search for
             skip_base_filter: Whether to skip base filtering
+            query_options: SQLAlchemy query options (e.g., joinedload,
+                subqueryload) to apply to the query for eager loading
 
         Returns:
             Model instance or None if not found
@@ -263,6 +266,9 @@ class BaseDAO(CoreBaseDAO[T], Generic[T]):
         query = db.session.query(cls.model_cls)
         query = cls._apply_base_filter(query, skip_base_filter)
 
+        if query_options:
+            query = query.options(*query_options)
+
         if not hasattr(cls.model_cls, column_name):
             return None
 
@@ -283,6 +289,7 @@ class BaseDAO(CoreBaseDAO[T], Generic[T]):
         model_id: str | int,
         skip_base_filter: bool = False,
         id_column: str | None = None,
+        query_options: list[Any] | None = None,
     ) -> T | None:
         """
         Find a model by ID using specified or default ID column.
@@ -291,12 +298,14 @@ class BaseDAO(CoreBaseDAO[T], Generic[T]):
             model_id: ID value to search for
             skip_base_filter: Whether to skip base filtering
             id_column: Column name to use (defaults to cls.id_column_name)
+            query_options: SQLAlchemy query options (e.g., joinedload,
+                subqueryload) to apply to the query for eager loading
 
         Returns:
             Model instance or None if not found
         """
         column = id_column or cls.id_column_name
-        return cls._find_by_column(column, model_id, skip_base_filter)
+        return cls._find_by_column(column, model_id, skip_base_filter, 
query_options)
 
     @classmethod
     def find_by_ids(
diff --git a/superset/daos/database.py b/superset/daos/database.py
index cd1bc3d51b3..5b1b33b3839 100644
--- a/superset/daos/database.py
+++ b/superset/daos/database.py
@@ -70,11 +70,15 @@ class DatabaseDAO(BaseDAO[Database]):
         model_id: str | int,
         skip_base_filter: bool = False,
         id_column: str | None = None,
+        query_options: list[Any] | None = None,
     ) -> Database | None:
         """
         Find a database by id, eagerly loading the SSH tunnel relationship.
         """
-        query = 
db.session.query(cls.model_cls).options(joinedload(Database.ssh_tunnel))
+        all_options = [joinedload(Database.ssh_tunnel)]
+        if query_options:
+            all_options.extend(query_options)
+        query = db.session.query(cls.model_cls).options(*all_options)
         query = cls._apply_base_filter(query, skip_base_filter)
 
         column_name = id_column or cls.id_column_name
diff --git a/superset/mcp_service/chart/tool/get_chart_info.py b/superset/mcp_service/chart/tool/get_chart_info.py
index de9dc314e6f..d3d92967408 100644
--- a/superset/mcp_service/chart/tool/get_chart_info.py
+++ b/superset/mcp_service/chart/tool/get_chart_info.py
@@ -22,6 +22,7 @@ MCP tool: get_chart_info
 import logging
 
 from fastmcp import Context
+from sqlalchemy.orm import subqueryload
 from superset_core.mcp import tool
 
 from superset.commands.exceptions import CommandException
@@ -93,6 +94,7 @@ async def get_chart_info(
     Returns chart details including name, type, and URL.
     """
     from superset.daos.chart import ChartDAO
+    from superset.models.slice import Slice
     from superset.utils import json as utils_json
 
     await ctx.info(
@@ -100,6 +102,12 @@ async def get_chart_info(
         % (request.identifier, request.form_data_key)
     )
 
+    # Eager load owners and tags to avoid N+1 queries during serialization
+    eager_options = [
+        subqueryload(Slice.owners),
+        subqueryload(Slice.tags),
+    ]
+
     with event_logger.log_context(action="mcp.get_chart_info.lookup"):
         tool = ModelGetInfoCore(
             dao_class=ChartDAO,
@@ -108,6 +116,7 @@ async def get_chart_info(
             serializer=serialize_chart_object,
             supports_slug=False,  # Charts don't have slugs
             logger=logger,
+            query_options=eager_options,
         )
 
         result = tool.run_tool(request.identifier)
diff --git a/superset/mcp_service/dashboard/tool/get_dashboard_info.py b/superset/mcp_service/dashboard/tool/get_dashboard_info.py
index ebca60a7bb7..31646db3753 100644
--- a/superset/mcp_service/dashboard/tool/get_dashboard_info.py
+++ b/superset/mcp_service/dashboard/tool/get_dashboard_info.py
@@ -26,6 +26,7 @@ import logging
 from datetime import datetime, timezone
 
 from fastmcp import Context
+from sqlalchemy.orm import subqueryload
 from superset_core.mcp import tool
 
 from superset.dashboards.permalink.exceptions import 
DashboardPermalinkGetFailedError
@@ -98,6 +99,19 @@ async def get_dashboard_info(
 
     try:
         from superset.daos.dashboard import DashboardDAO
+        from superset.models.dashboard import Dashboard
+        from superset.models.slice import Slice
+
+        # Eager load slices (charts), owners, tags, and roles to avoid N+1
+        # queries. Also eager load owners/tags on each slice since the
+        # dashboard serializer calls serialize_chart_object for every chart.
+        eager_options = [
+            subqueryload(Dashboard.slices).subqueryload(Slice.owners),
+            subqueryload(Dashboard.slices).subqueryload(Slice.tags),
+            subqueryload(Dashboard.owners),
+            subqueryload(Dashboard.tags),
+            subqueryload(Dashboard.roles),
+        ]
 
         with event_logger.log_context(action="mcp.get_dashboard_info.lookup"):
             tool = ModelGetInfoCore(
@@ -107,6 +121,7 @@ async def get_dashboard_info(
                 serializer=dashboard_serializer,
                 supports_slug=True,  # Dashboards support slugs
                 logger=logger,
+                query_options=eager_options,
             )
 
             result = tool.run_tool(request.identifier)
diff --git a/superset/mcp_service/dataset/tool/get_dataset_info.py b/superset/mcp_service/dataset/tool/get_dataset_info.py
index e9e8817d2d9..35c963eb2bd 100644
--- a/superset/mcp_service/dataset/tool/get_dataset_info.py
+++ b/superset/mcp_service/dataset/tool/get_dataset_info.py
@@ -26,6 +26,7 @@ import logging
 from datetime import datetime, timezone
 
 from fastmcp import Context
+from sqlalchemy.orm import joinedload, subqueryload
 from superset_core.mcp import tool
 
 from superset.extensions import event_logger
@@ -82,8 +83,18 @@ async def get_dataset_info(
     )
 
     try:
+        from superset.connectors.sqla.models import SqlaTable
         from superset.daos.dataset import DatasetDAO
 
+        # Eager load columns, metrics, and database to avoid N+1 queries.
+        # Without this, serialize_dataset_object triggers lazy loads for each
+        # relationship, which can time out on datasets with many columns.
+        eager_options = [
+            subqueryload(SqlaTable.columns),
+            subqueryload(SqlaTable.metrics),
+            joinedload(SqlaTable.database),
+        ]
+
         with event_logger.log_context(action="mcp.get_dataset_info.lookup"):
             tool = ModelGetInfoCore(
                 dao_class=DatasetDAO,
@@ -92,6 +103,7 @@ async def get_dataset_info(
                 serializer=serialize_dataset_object,
                 supports_slug=False,  # Datasets don't have slugs
                 logger=logger,
+                query_options=eager_options,
             )
 
             result = tool.run_tool(request.identifier)
diff --git a/superset/mcp_service/mcp_core.py b/superset/mcp_service/mcp_core.py
index ef0612553c1..3ac0f008ca4 100644
--- a/superset/mcp_service/mcp_core.py
+++ b/superset/mcp_service/mcp_core.py
@@ -240,6 +240,7 @@ class ModelGetInfoCore(BaseCore):
         serializer: Callable[[T], BaseModel],
         supports_slug: bool = False,
         logger: logging.Logger | None = None,
+        query_options: list[Any] | None = None,
     ) -> None:
         super().__init__(logger)
         self.dao_class = dao_class
@@ -247,29 +248,35 @@ class ModelGetInfoCore(BaseCore):
         self.error_schema = error_schema
         self.serializer = serializer
         self.supports_slug = supports_slug
+        self.query_options = query_options or []
 
     def _find_object(self, identifier: int | str) -> Any:
         """Find object by identifier using appropriate method."""
+        opts = self.query_options or None
         # If it's an integer or string that can be converted to int, use 
find_by_id
         if isinstance(identifier, int):
-            return self.dao_class.find_by_id(identifier)
+            return self.dao_class.find_by_id(identifier, query_options=opts)
 
         try:
             # Try to convert string to int
             id_val = int(identifier)
-            return self.dao_class.find_by_id(id_val)
+            return self.dao_class.find_by_id(id_val, query_options=opts)
         except ValueError:
             pass
 
         # Check if it's a UUID
         if _is_uuid(identifier):
             # Use the new flexible find_by_id with uuid column
-            return self.dao_class.find_by_id(identifier, id_column="uuid")
+            return self.dao_class.find_by_id(
+                identifier, id_column="uuid", query_options=opts
+            )
 
         # For dashboards, also check slug
         if self.supports_slug:
             # Try to find by slug using the new flexible method
-            result = self.dao_class.find_by_id(identifier, id_column="slug")
+            result = self.dao_class.find_by_id(
+                identifier, id_column="slug", query_options=opts
+            )
             if result:
                 return result
 
@@ -278,11 +285,10 @@ class ModelGetInfoCore(BaseCore):
             from superset.models.dashboard import id_or_slug_filter
 
             model_class = self.dao_class.model_cls
-            return (
-                db.session.query(model_class)
-                .filter(id_or_slug_filter(identifier))
-                .one_or_none()
-            )
+            query = 
db.session.query(model_class).filter(id_or_slug_filter(identifier))
+            if opts:
+                query = query.options(*opts)
+            return query.one_or_none()
 
         # If we get here, it's an invalid identifier
         return None
diff --git a/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py b/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py
index ad3ece3d498..aefdd8c8433 100644
--- a/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py
+++ b/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py
@@ -69,7 +69,7 @@ class DummyDAO:
         return [SimpleNamespace(id=1, name="foo"), SimpleNamespace(id=2, 
name="bar")], 2
 
     @classmethod
-    def find_by_id(cls, id):
+    def find_by_id(cls, id, **kwargs):
         if id == 1:
             return SimpleNamespace(id=1, name="foo")
         return None
@@ -196,7 +196,7 @@ def test_model_get_info_tool_not_found():
 def test_model_get_info_tool_exception():
     class FailingDAO:
         @classmethod
-        def find_by_id(cls, id):
+        def find_by_id(cls, id, **kwargs):
             raise Exception("fail")
 
     tool = ModelGetInfoCore(

Reply via email to