This is an automated email from the ASF dual-hosted git repository.
arivero pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new c54b21ef988 fix(mcp): add eager loading to get_info tools to prevent
N+1 query timeouts (#38129)
c54b21ef988 is described below
commit c54b21ef988a606012f24d65dde7ae720f64ba1d
Author: Amin Ghadersohi <[email protected]>
AuthorDate: Wed Feb 25 11:28:58 2026 -0500
fix(mcp): add eager loading to get_info tools to prevent N+1 query timeouts
(#38129)
Co-authored-by: Claude Opus 4.6 <[email protected]>
---
superset/daos/base.py | 11 +++++++++-
superset/daos/database.py | 6 +++++-
superset/mcp_service/chart/tool/get_chart_info.py | 9 ++++++++
.../dashboard/tool/get_dashboard_info.py | 15 ++++++++++++++
.../mcp_service/dataset/tool/get_dataset_info.py | 12 +++++++++++
superset/mcp_service/mcp_core.py | 24 ++++++++++++++--------
.../mcp_service/system/tool/test_mcp_core.py | 4 ++--
7 files changed, 68 insertions(+), 13 deletions(-)
diff --git a/superset/daos/base.py b/superset/daos/base.py
index 4897f796cf8..136184304b5 100644
--- a/superset/daos/base.py
+++ b/superset/daos/base.py
@@ -248,6 +248,7 @@ class BaseDAO(CoreBaseDAO[T], Generic[T]):
column_name: str,
value: str | int,
skip_base_filter: bool = False,
+ query_options: list[Any] | None = None,
) -> T | None:
"""
Private method to find a model by any column value.
@@ -256,6 +257,8 @@ class BaseDAO(CoreBaseDAO[T], Generic[T]):
column_name: Name of the column to search by
value: Value to search for
skip_base_filter: Whether to skip base filtering
+ query_options: SQLAlchemy query options (e.g., joinedload,
+ subqueryload) to apply to the query for eager loading
Returns:
Model instance or None if not found
@@ -263,6 +266,9 @@ class BaseDAO(CoreBaseDAO[T], Generic[T]):
query = db.session.query(cls.model_cls)
query = cls._apply_base_filter(query, skip_base_filter)
+ if query_options:
+ query = query.options(*query_options)
+
if not hasattr(cls.model_cls, column_name):
return None
@@ -283,6 +289,7 @@ class BaseDAO(CoreBaseDAO[T], Generic[T]):
model_id: str | int,
skip_base_filter: bool = False,
id_column: str | None = None,
+ query_options: list[Any] | None = None,
) -> T | None:
"""
Find a model by ID using specified or default ID column.
@@ -291,12 +298,14 @@ class BaseDAO(CoreBaseDAO[T], Generic[T]):
model_id: ID value to search for
skip_base_filter: Whether to skip base filtering
id_column: Column name to use (defaults to cls.id_column_name)
+ query_options: SQLAlchemy query options (e.g., joinedload,
+ subqueryload) to apply to the query for eager loading
Returns:
Model instance or None if not found
"""
column = id_column or cls.id_column_name
- return cls._find_by_column(column, model_id, skip_base_filter)
+ return cls._find_by_column(column, model_id, skip_base_filter,
query_options)
@classmethod
def find_by_ids(
diff --git a/superset/daos/database.py b/superset/daos/database.py
index cd1bc3d51b3..5b1b33b3839 100644
--- a/superset/daos/database.py
+++ b/superset/daos/database.py
@@ -70,11 +70,15 @@ class DatabaseDAO(BaseDAO[Database]):
model_id: str | int,
skip_base_filter: bool = False,
id_column: str | None = None,
+ query_options: list[Any] | None = None,
) -> Database | None:
"""
Find a database by id, eagerly loading the SSH tunnel relationship.
"""
- query =
db.session.query(cls.model_cls).options(joinedload(Database.ssh_tunnel))
+ all_options = [joinedload(Database.ssh_tunnel)]
+ if query_options:
+ all_options.extend(query_options)
+ query = db.session.query(cls.model_cls).options(*all_options)
query = cls._apply_base_filter(query, skip_base_filter)
column_name = id_column or cls.id_column_name
diff --git a/superset/mcp_service/chart/tool/get_chart_info.py
b/superset/mcp_service/chart/tool/get_chart_info.py
index de9dc314e6f..d3d92967408 100644
--- a/superset/mcp_service/chart/tool/get_chart_info.py
+++ b/superset/mcp_service/chart/tool/get_chart_info.py
@@ -22,6 +22,7 @@ MCP tool: get_chart_info
import logging
from fastmcp import Context
+from sqlalchemy.orm import subqueryload
from superset_core.mcp import tool
from superset.commands.exceptions import CommandException
@@ -93,6 +94,7 @@ async def get_chart_info(
Returns chart details including name, type, and URL.
"""
from superset.daos.chart import ChartDAO
+ from superset.models.slice import Slice
from superset.utils import json as utils_json
await ctx.info(
@@ -100,6 +102,12 @@ async def get_chart_info(
% (request.identifier, request.form_data_key)
)
+ # Eager load owners and tags to avoid N+1 queries during serialization
+ eager_options = [
+ subqueryload(Slice.owners),
+ subqueryload(Slice.tags),
+ ]
+
with event_logger.log_context(action="mcp.get_chart_info.lookup"):
tool = ModelGetInfoCore(
dao_class=ChartDAO,
@@ -108,6 +116,7 @@ async def get_chart_info(
serializer=serialize_chart_object,
supports_slug=False, # Charts don't have slugs
logger=logger,
+ query_options=eager_options,
)
result = tool.run_tool(request.identifier)
diff --git a/superset/mcp_service/dashboard/tool/get_dashboard_info.py
b/superset/mcp_service/dashboard/tool/get_dashboard_info.py
index ebca60a7bb7..31646db3753 100644
--- a/superset/mcp_service/dashboard/tool/get_dashboard_info.py
+++ b/superset/mcp_service/dashboard/tool/get_dashboard_info.py
@@ -26,6 +26,7 @@ import logging
from datetime import datetime, timezone
from fastmcp import Context
+from sqlalchemy.orm import subqueryload
from superset_core.mcp import tool
from superset.dashboards.permalink.exceptions import
DashboardPermalinkGetFailedError
@@ -98,6 +99,19 @@ async def get_dashboard_info(
try:
from superset.daos.dashboard import DashboardDAO
+ from superset.models.dashboard import Dashboard
+ from superset.models.slice import Slice
+
+ # Eager load slices (charts), owners, tags, and roles to avoid N+1
+ # queries. Also eager load owners/tags on each slice since the
+ # dashboard serializer calls serialize_chart_object for every chart.
+ eager_options = [
+ subqueryload(Dashboard.slices).subqueryload(Slice.owners),
+ subqueryload(Dashboard.slices).subqueryload(Slice.tags),
+ subqueryload(Dashboard.owners),
+ subqueryload(Dashboard.tags),
+ subqueryload(Dashboard.roles),
+ ]
with event_logger.log_context(action="mcp.get_dashboard_info.lookup"):
tool = ModelGetInfoCore(
@@ -107,6 +121,7 @@ async def get_dashboard_info(
serializer=dashboard_serializer,
supports_slug=True, # Dashboards support slugs
logger=logger,
+ query_options=eager_options,
)
result = tool.run_tool(request.identifier)
diff --git a/superset/mcp_service/dataset/tool/get_dataset_info.py
b/superset/mcp_service/dataset/tool/get_dataset_info.py
index e9e8817d2d9..35c963eb2bd 100644
--- a/superset/mcp_service/dataset/tool/get_dataset_info.py
+++ b/superset/mcp_service/dataset/tool/get_dataset_info.py
@@ -26,6 +26,7 @@ import logging
from datetime import datetime, timezone
from fastmcp import Context
+from sqlalchemy.orm import joinedload, subqueryload
from superset_core.mcp import tool
from superset.extensions import event_logger
@@ -82,8 +83,18 @@ async def get_dataset_info(
)
try:
+ from superset.connectors.sqla.models import SqlaTable
from superset.daos.dataset import DatasetDAO
+ # Eager load columns, metrics, and database to avoid N+1 queries.
+ # Without this, serialize_dataset_object triggers lazy loads for each
+ # relationship, which can time out on datasets with many columns.
+ eager_options = [
+ subqueryload(SqlaTable.columns),
+ subqueryload(SqlaTable.metrics),
+ joinedload(SqlaTable.database),
+ ]
+
with event_logger.log_context(action="mcp.get_dataset_info.lookup"):
tool = ModelGetInfoCore(
dao_class=DatasetDAO,
@@ -92,6 +103,7 @@ async def get_dataset_info(
serializer=serialize_dataset_object,
supports_slug=False, # Datasets don't have slugs
logger=logger,
+ query_options=eager_options,
)
result = tool.run_tool(request.identifier)
diff --git a/superset/mcp_service/mcp_core.py b/superset/mcp_service/mcp_core.py
index ef0612553c1..3ac0f008ca4 100644
--- a/superset/mcp_service/mcp_core.py
+++ b/superset/mcp_service/mcp_core.py
@@ -240,6 +240,7 @@ class ModelGetInfoCore(BaseCore):
serializer: Callable[[T], BaseModel],
supports_slug: bool = False,
logger: logging.Logger | None = None,
+ query_options: list[Any] | None = None,
) -> None:
super().__init__(logger)
self.dao_class = dao_class
@@ -247,29 +248,35 @@ class ModelGetInfoCore(BaseCore):
self.error_schema = error_schema
self.serializer = serializer
self.supports_slug = supports_slug
+ self.query_options = query_options or []
def _find_object(self, identifier: int | str) -> Any:
"""Find object by identifier using appropriate method."""
+ opts = self.query_options or None
# If it's an integer or string that can be converted to int, use
find_by_id
if isinstance(identifier, int):
- return self.dao_class.find_by_id(identifier)
+ return self.dao_class.find_by_id(identifier, query_options=opts)
try:
# Try to convert string to int
id_val = int(identifier)
- return self.dao_class.find_by_id(id_val)
+ return self.dao_class.find_by_id(id_val, query_options=opts)
except ValueError:
pass
# Check if it's a UUID
if _is_uuid(identifier):
# Use the new flexible find_by_id with uuid column
- return self.dao_class.find_by_id(identifier, id_column="uuid")
+ return self.dao_class.find_by_id(
+ identifier, id_column="uuid", query_options=opts
+ )
# For dashboards, also check slug
if self.supports_slug:
# Try to find by slug using the new flexible method
- result = self.dao_class.find_by_id(identifier, id_column="slug")
+ result = self.dao_class.find_by_id(
+ identifier, id_column="slug", query_options=opts
+ )
if result:
return result
@@ -278,11 +285,10 @@ class ModelGetInfoCore(BaseCore):
from superset.models.dashboard import id_or_slug_filter
model_class = self.dao_class.model_cls
- return (
- db.session.query(model_class)
- .filter(id_or_slug_filter(identifier))
- .one_or_none()
- )
+ query =
db.session.query(model_class).filter(id_or_slug_filter(identifier))
+ if opts:
+ query = query.options(*opts)
+ return query.one_or_none()
# If we get here, it's an invalid identifier
return None
diff --git a/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py
b/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py
index ad3ece3d498..aefdd8c8433 100644
--- a/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py
+++ b/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py
@@ -69,7 +69,7 @@ class DummyDAO:
return [SimpleNamespace(id=1, name="foo"), SimpleNamespace(id=2,
name="bar")], 2
@classmethod
- def find_by_id(cls, id):
+ def find_by_id(cls, id, **kwargs):
if id == 1:
return SimpleNamespace(id=1, name="foo")
return None
@@ -196,7 +196,7 @@ def test_model_get_info_tool_not_found():
def test_model_get_info_tool_exception():
class FailingDAO:
@classmethod
- def find_by_id(cls, id):
+ def find_by_id(cls, id, **kwargs):
raise Exception("fail")
tool = ModelGetInfoCore(