This is an automated email from the ASF dual-hosted git repository.
kgabryje pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new a1312a86e88 fix(mcp): normalize column names to fix time series filter prompt issue (#37187)
a1312a86e88 is described below
commit a1312a86e88fb7f9cf6d9bff1e3627057f33abbd
Author: Amin Ghadersohi <[email protected]>
AuthorDate: Wed Feb 25 09:27:53 2026 -0500
fix(mcp): normalize column names to fix time series filter prompt issue (#37187)
Co-authored-by: Claude Opus 4.5 <[email protected]>
---
.../chart/validation/dataset_validator.py | 153 ++++-
superset/mcp_service/chart/validation/pipeline.py | 86 ++-
.../explore/tool/generate_explore_link.py | 15 +-
.../validation/test_column_name_normalization.py | 681 +++++++++++++++++++++
.../explore/tool/test_generate_explore_link.py | 149 +++++
5 files changed, 1073 insertions(+), 11 deletions(-)
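In short, the change looks up each user-supplied column or metric name case-insensitively against the dataset and rewrites it to the dataset's canonical spelling before it reaches form_data. A minimal illustrative sketch of that lookup (not taken from the commit; the dict shapes follow the DatasetContext columns used in the tests below):

    def get_canonical_column_name(name, available_columns, available_metrics):
        """Return the dataset's exact spelling of name, or name itself if unknown."""
        lowered = name.lower()
        for col in available_columns:              # physical/virtual columns first
            if col["name"].lower() == lowered:
                return col["name"]                 # e.g. 'orderdate' -> 'OrderDate'
        for metric in available_metrics:           # then saved metrics
            if metric["name"].lower() == lowered:
                return metric["name"]
        return name                                # unknown names pass through

    # The Explore frontend compares names case-sensitively, so giving the x-axis
    # and any existing filter the same canonical spelling removes the spurious
    # "add x-axis to filters" prompt this commit fixes.
    columns = [{"name": "OrderDate"}, {"name": "Sales"}]
    assert get_canonical_column_name("orderdate", columns, []) == "OrderDate"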
diff --git a/superset/mcp_service/chart/validation/dataset_validator.py b/superset/mcp_service/chart/validation/dataset_validator.py
index b03d0ffe9c0..c50fccc6938 100644
--- a/superset/mcp_service/chart/validation/dataset_validator.py
+++ b/superset/mcp_service/chart/validation/dataset_validator.py
@@ -22,7 +22,7 @@ Validates that referenced columns exist in the dataset schema.
import difflib
import logging
-from typing import Dict, List, Tuple
+from typing import Any, Dict, List, Tuple
from superset.mcp_service.chart.schemas import (
ColumnRef,
@@ -37,13 +37,25 @@ from superset.mcp_service.common.error_schemas import (
logger = logging.getLogger(__name__)
+# Exceptions that can occur during column name normalization.
+# Shared by the validation pipeline and tool-level normalization calls.
+NORMALIZATION_EXCEPTIONS = (
+ ImportError,
+ AttributeError,
+ KeyError,
+ ValueError,
+ TypeError,
+)
+
class DatasetValidator:
"""Validates chart configuration against dataset schema."""
@staticmethod
def validate_against_dataset(
- config: TableChartConfig | XYChartConfig, dataset_id: int | str
+ config: TableChartConfig | XYChartConfig,
+ dataset_id: int | str,
+ dataset_context: DatasetContext | None = None,
) -> Tuple[bool, ChartGenerationError | None]:
"""
Validate chart configuration against dataset schema.
@@ -51,12 +63,15 @@ class DatasetValidator:
Args:
config: Chart configuration to validate
dataset_id: Dataset ID to validate against
+ dataset_context: Pre-fetched dataset context to avoid duplicate
+ DB queries. If None, fetches from the database.
Returns:
Tuple of (is_valid, error)
"""
- # Get dataset context
- dataset_context = DatasetValidator._get_dataset_context(dataset_id)
+ # Get dataset context (reuse if provided)
+ if dataset_context is None:
+ dataset_context = DatasetValidator._get_dataset_context(dataset_id)
if not dataset_context:
from superset.mcp_service.utils.error_builder import (
ChartErrorBuilder,
@@ -198,6 +213,136 @@ class DatasetValidator:
return False
+ @staticmethod
+ def _get_canonical_column_name(
+ column_name: str, dataset_context: DatasetContext
+ ) -> str:
+ """
+ Get the canonical column name from the dataset.
+
+ Performs case-insensitive matching and returns the actual column name
+ as stored in the dataset. This ensures column names in form_data match
+ exactly with what the frontend expects.
+
+ Args:
+ column_name: The column name to normalize
+ dataset_context: Dataset context with column information
+
+ Returns:
+ The canonical column name from the dataset, or the original name
+ if no match is found.
+ """
+ column_lower = column_name.lower()
+
+ # Check regular columns first
+ for col in dataset_context.available_columns:
+ if col["name"].lower() == column_lower:
+ return col["name"]
+
+ # Check metrics
+ for metric in dataset_context.available_metrics:
+ if metric["name"].lower() == column_lower:
+ return metric["name"]
+
+ # Return original if not found (validation should catch this case)
+ return column_name
+
+ @staticmethod
+ def _normalize_xy_config(
+ config_dict: Dict[str, Any], dataset_context: DatasetContext
+ ) -> None:
+ """Normalize column names in an XY chart config dict in place."""
+ # Normalize x-axis column
+ if "x" in config_dict and config_dict["x"]:
+ config_dict["x"]["name"] =
DatasetValidator._get_canonical_column_name(
+ config_dict["x"]["name"], dataset_context
+ )
+
+ # Normalize y-axis columns
+ if "y" in config_dict and config_dict["y"]:
+ for y_col in config_dict["y"]:
+ y_col["name"] = DatasetValidator._get_canonical_column_name(
+ y_col["name"], dataset_context
+ )
+
+ # Normalize group_by column
+ if "group_by" in config_dict and config_dict["group_by"]:
+ config_dict["group_by"]["name"] = (
+ DatasetValidator._get_canonical_column_name(
+ config_dict["group_by"]["name"], dataset_context
+ )
+ )
+
+ @staticmethod
+ def _normalize_table_config(
+ config_dict: Dict[str, Any], dataset_context: DatasetContext
+ ) -> None:
+ """Normalize column names in a table chart config dict in place."""
+ if "columns" in config_dict and config_dict["columns"]:
+ for col in config_dict["columns"]:
+ col["name"] = DatasetValidator._get_canonical_column_name(
+ col["name"], dataset_context
+ )
+
+ @staticmethod
+ def _normalize_filters(
+ config_dict: Dict[str, Any], dataset_context: DatasetContext
+ ) -> None:
+ """Normalize filter column names in a config dict in place."""
+ if "filters" in config_dict and config_dict["filters"]:
+ for filter_config in config_dict["filters"]:
+ if filter_config and "column" in filter_config:
+ filter_config["column"] = (
+ DatasetValidator._get_canonical_column_name(
+ filter_config["column"], dataset_context
+ )
+ )
+
+ @staticmethod
+ def normalize_column_names(
+ config: TableChartConfig | XYChartConfig,
+ dataset_id: int | str,
+ dataset_context: DatasetContext | None = None,
+ ) -> TableChartConfig | XYChartConfig:
+ """
+ Normalize column names in config to match the canonical dataset column names.
+
+ This fixes case sensitivity issues where user-provided column names
+ (e.g., 'order_date') don't match exactly with the dataset column names
+ (e.g., 'OrderDate'). The frontend performs case-sensitive comparisons,
+ so we need to ensure column names match exactly.
+
+ Args:
+ config: Chart configuration with column references
+ dataset_id: Dataset ID to get canonical column names from
+ dataset_context: Pre-fetched dataset context to avoid duplicate
+ DB queries. If None, fetches from the database.
+
+ Returns:
+ A new config with normalized column names
+ """
+ if dataset_context is None:
+ dataset_context = DatasetValidator._get_dataset_context(dataset_id)
+ if not dataset_context:
+ return config
+
+ # Create a mutable copy of the config
+ config_dict = config.model_dump()
+
+ # Normalize based on config type
+ if isinstance(config, XYChartConfig):
+ DatasetValidator._normalize_xy_config(config_dict, dataset_context)
+ elif isinstance(config, TableChartConfig):
+ DatasetValidator._normalize_table_config(config_dict, dataset_context)
+
+ # Normalize filter columns (common to both config types)
+ DatasetValidator._normalize_filters(config_dict, dataset_context)
+
+ # Reconstruct the config with normalized names
+ if isinstance(config, XYChartConfig):
+ return XYChartConfig.model_validate(config_dict)
+ return TableChartConfig.model_validate(config_dict)
+
@staticmethod
def _get_column_suggestions(
column_name: str, dataset_context: DatasetContext, max_suggestions: int = 3
diff --git a/superset/mcp_service/chart/validation/pipeline.py b/superset/mcp_service/chart/validation/pipeline.py
index 948f9d2e62d..a0f475ffa0e 100644
--- a/superset/mcp_service/chart/validation/pipeline.py
+++ b/superset/mcp_service/chart/validation/pipeline.py
@@ -27,7 +27,10 @@ from superset.mcp_service.chart.schemas import (
ChartConfig,
GenerateChartRequest,
)
-from superset.mcp_service.common.error_schemas import ChartGenerationError
+from superset.mcp_service.common.error_schemas import (
+ ChartGenerationError,
+ DatasetContext,
+)
logger = logging.getLogger(__name__)
@@ -168,9 +171,14 @@ class ValidationPipeline:
if request is None:
return ValidationResult(is_valid=False, error=error)
- # Layer 2: Dataset validation
+ # Fetch dataset context once and reuse across validation layers
+ dataset_context = ValidationPipeline._get_dataset_context(
+ request.dataset_id
+ )
+
+ # Layer 2: Dataset validation (reuses context)
is_valid, error = ValidationPipeline._validate_dataset(
- request.config, request.dataset_id
+ request.config, request.dataset_id, dataset_context
)
if not is_valid:
return ValidationResult(is_valid=False, request=request, error=error)
@@ -181,8 +189,15 @@ class ValidationPipeline:
)
# Runtime validation always returns True now, warnings are informational
+ # Layer 4: Column name normalization (reuses context)
+ normalized_request = ValidationPipeline._normalize_column_names(
+ request, dataset_context
+ )
+
return ValidationResult(
- is_valid=True, request=request, warnings=warnings_metadata
+ is_valid=True,
+ request=normalized_request,
+ warnings=warnings_metadata,
)
except Exception as e:
@@ -201,15 +216,32 @@ class ValidationPipeline:
)
return ValidationResult(is_valid=False, error=error)
+ @staticmethod
+ def _get_dataset_context(
+ dataset_id: int | str,
+ ) -> DatasetContext | None:
+ """Fetch dataset context once to reuse across validation layers."""
+ try:
+ from .dataset_validator import DatasetValidator
+
+ return DatasetValidator._get_dataset_context(dataset_id)
+ except ImportError:
+ logger.warning("Dataset validator not available, skipping context
fetch")
+ return None
+
@staticmethod
def _validate_dataset(
- config: ChartConfig, dataset_id: int | str
+ config: ChartConfig,
+ dataset_id: int | str,
+ dataset_context: DatasetContext | None = None,
) -> Tuple[bool, ChartGenerationError | None]:
"""Validate configuration against dataset schema."""
try:
from .dataset_validator import DatasetValidator
- return DatasetValidator.validate_against_dataset(config, dataset_id)
+ return DatasetValidator.validate_against_dataset(
+ config, dataset_id, dataset_context=dataset_context
+ )
except ImportError:
# Skip if dataset validator not available
logger.warning(
@@ -248,6 +280,48 @@ class ValidationPipeline:
# Don't fail on runtime validation errors
return True, None
+ @staticmethod
+ def _normalize_column_names(
+ request: GenerateChartRequest,
+ dataset_context: DatasetContext | None = None,
+ ) -> GenerateChartRequest:
+ """
+ Normalize column names in the request to match canonical dataset names.
+
+ This fixes case sensitivity issues where user-provided column names
+ don't match exactly with the dataset column names. For example,
+ if a user provides 'order_date' but the dataset has 'OrderDate',
+ this method will normalize it to 'OrderDate'.
+
+ Args:
+ request: The validated chart generation request
+ dataset_context: Pre-fetched dataset context to avoid duplicate
+ DB queries. If None, fetches from the database.
+
+ Returns:
+ A new request with normalized column names
+ """
+ try:
+ from .dataset_validator import DatasetValidator
+
+ normalized_config = DatasetValidator.normalize_column_names(
+ request.config,
+ request.dataset_id,
+ dataset_context=dataset_context,
+ )
+
+ # Create a new request with the normalized config
+ request_dict = request.model_dump()
+ request_dict["config"] = normalized_config.model_dump()
+
+ return GenerateChartRequest.model_validate(request_dict)
+
+ except (ImportError, AttributeError, KeyError, ValueError, TypeError) as e:
+ # If normalization fails, return the original request
+ # Validation has already passed, so this is a non-critical failure
+ logger.warning("Column name normalization failed: %s", e)
+ return request
+
@staticmethod
def validate_filters(
filters: List[Any],
diff --git a/superset/mcp_service/explore/tool/generate_explore_link.py b/superset/mcp_service/explore/tool/generate_explore_link.py
index 3048a538c6c..dca721caa8b 100644
--- a/superset/mcp_service/explore/tool/generate_explore_link.py
+++ b/superset/mcp_service/explore/tool/generate_explore_link.py
@@ -93,9 +93,22 @@ async def generate_explore_link(
try:
await ctx.report_progress(1, 3, "Converting configuration to form data")
with event_logger.log_context(action="mcp.generate_explore_link.form_data"):
+ # Normalize column names to match canonical dataset column names
+ # This fixes case sensitivity issues (e.g., 'order_date' vs 'OrderDate')
+ try:
+ from superset.mcp_service.chart.validation.dataset_validator import (
+ DatasetValidator,
+ )
+
+ normalized_config = DatasetValidator.normalize_column_names(
+ request.config, request.dataset_id
+ )
+ except (ImportError, AttributeError, KeyError, ValueError, TypeError):
+ normalized_config = request.config
+
# Map config to form_data using shared utilities
form_data = map_config_to_form_data(
- request.config, dataset_id=request.dataset_id
+ normalized_config, dataset_id=request.dataset_id
)
# Add datasource to form_data for consistency with generate_chart
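The tool mirrors the pipeline's best-effort approach: attempt normalization and fall back to the original config if anything in the lookup fails, so normalization can never break link generation. A hedged sketch of that fallback shape (illustrative only; normalize stands in for DatasetValidator.normalize_column_names):

    import logging

    logger = logging.getLogger(__name__)

    def normalize_or_passthrough(config, dataset_id, normalize):
        """Best-effort normalization: never fail the request over a lookup problem."""
        try:
            return normalize(config, dataset_id)
        except (ImportError, AttributeError, KeyError, ValueError, TypeError) as exc:
            logger.warning("Column name normalization failed: %s", exc)
            return config  # keep the caller's original column names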
diff --git a/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py
new file mode 100644
index 00000000000..77fdf64143f
--- /dev/null
+++ b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py
@@ -0,0 +1,681 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests for column name normalization in the MCP service.
+
+This tests the fix for the issue where time series charts would incorrectly
+prompt to add the x-axis to filters when the column name case didn't match
+exactly (e.g., 'order_date' vs 'OrderDate').
+"""
+
+from typing import Any, Dict
+from unittest.mock import patch
+
+import pytest
+
+from superset.mcp_service.chart.schemas import (
+ ColumnRef,
+ FilterConfig,
+ TableChartConfig,
+ XYChartConfig,
+)
+from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
+from superset.mcp_service.common.error_schemas import DatasetContext
+
+
+@pytest.fixture
+def mock_dataset_context() -> DatasetContext:
+ """Create a mock dataset context with mixed-case column names."""
+ return DatasetContext(
+ id=18,
+ table_name="Vehicle Sales",
+ schema="public",
+ database_name="examples",
+ available_columns=[
+ {"name": "OrderDate", "type": "DATE", "is_temporal": True},
+ {"name": "ProductLine", "type": "VARCHAR", "is_temporal": False},
+ {"name": "Sales", "type": "DECIMAL", "is_numeric": True},
+ {"name": "quantity_ordered", "type": "INTEGER", "is_numeric":
True},
+ ],
+ available_metrics=[
+ {"name": "TotalRevenue", "expression": "SUM(Sales)",
"description": None},
+ ],
+ )
+
+
+class TestGetCanonicalColumnName:
+ """Test _get_canonical_column_name static method."""
+
+ def test_exact_match_returns_same_name(
+ self, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test that exact match returns the same column name."""
+ result = DatasetValidator._get_canonical_column_name(
+ "OrderDate", mock_dataset_context
+ )
+ assert result == "OrderDate"
+
+ def test_lowercase_returns_canonical_name(
+ self, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test that lowercase input returns the canonical (dataset) column
name."""
+ result = DatasetValidator._get_canonical_column_name(
+ "orderdate", mock_dataset_context
+ )
+ assert result == "OrderDate"
+
+ def test_snake_case_returns_canonical_name(
+ self, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test that snake_case input returns the canonical column name."""
+ # 'order_date' won't match 'OrderDate' directly, but would match if
+ # the dataset had 'order_date'. This test verifies case-insensitive matching.
+ result = DatasetValidator._get_canonical_column_name(
+ "productline", mock_dataset_context
+ )
+ assert result == "ProductLine"
+
+ def test_uppercase_returns_canonical_name(
+ self, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test that uppercase input returns the canonical column name."""
+ result = DatasetValidator._get_canonical_column_name(
+ "SALES", mock_dataset_context
+ )
+ assert result == "Sales"
+
+ def test_metric_name_normalization(
+ self, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test that metric names are also normalized."""
+ result = DatasetValidator._get_canonical_column_name(
+ "totalrevenue", mock_dataset_context
+ )
+ assert result == "TotalRevenue"
+
+ def test_unknown_column_returns_original(
+ self, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test that unknown columns return the original name."""
+ result = DatasetValidator._get_canonical_column_name(
+ "unknown_column", mock_dataset_context
+ )
+ assert result == "unknown_column"
+
+
+class TestNormalizeXYConfig:
+ """Test _normalize_xy_config static method."""
+
+ def test_normalize_x_axis_column(
+ self, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test that x-axis column name is normalized."""
+ config_dict: Dict[str, Any] = {
+ "chart_type": "xy",
+ "x": {"name": "orderdate"},
+ "y": [{"name": "Sales", "aggregate": "SUM"}],
+ "kind": "line",
+ }
+
+ DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
+
+ assert config_dict["x"]["name"] == "OrderDate"
+
+ def test_normalize_y_axis_columns(
+ self, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test that y-axis column names are normalized."""
+ config_dict: Dict[str, Any] = {
+ "chart_type": "xy",
+ "x": {"name": "OrderDate"},
+ "y": [
+ {"name": "sales", "aggregate": "SUM"},
+ {"name": "QUANTITY_ORDERED", "aggregate": "COUNT"},
+ ],
+ "kind": "bar",
+ }
+
+ DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
+
+ assert config_dict["y"][0]["name"] == "Sales"
+ assert config_dict["y"][1]["name"] == "quantity_ordered"
+
+ def test_normalize_group_by_column(
+ self, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test that group_by column name is normalized."""
+ config_dict: Dict[str, Any] = {
+ "chart_type": "xy",
+ "x": {"name": "OrderDate"},
+ "y": [{"name": "Sales", "aggregate": "SUM"}],
+ "kind": "line",
+ "group_by": {"name": "productline"},
+ }
+
+ DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
+
+ assert config_dict["group_by"]["name"] == "ProductLine"
+
+
+class TestNormalizeTableConfig:
+ """Test _normalize_table_config static method."""
+
+ def test_normalize_table_columns(
+ self, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test that table column names are normalized."""
+ config_dict: Dict[str, Any] = {
+ "chart_type": "table",
+ "columns": [
+ {"name": "orderdate"},
+ {"name": "PRODUCTLINE"},
+ {"name": "sales", "aggregate": "SUM"},
+ ],
+ }
+
+ DatasetValidator._normalize_table_config(config_dict, mock_dataset_context)
+
+ assert config_dict["columns"][0]["name"] == "OrderDate"
+ assert config_dict["columns"][1]["name"] == "ProductLine"
+ assert config_dict["columns"][2]["name"] == "Sales"
+
+
+class TestNormalizeFilters:
+ """Test _normalize_filters static method."""
+
+ def test_normalize_filter_columns(
+ self, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test that filter column names are normalized."""
+ config_dict: Dict[str, Any] = {
+ "filters": [
+ {"column": "productline", "op": "=", "value": "Classic Cars"},
+ {"column": "ORDERDATE", "op": ">", "value": "2023-01-01"},
+ ],
+ }
+
+ DatasetValidator._normalize_filters(config_dict, mock_dataset_context)
+
+ assert config_dict["filters"][0]["column"] == "ProductLine"
+ assert config_dict["filters"][1]["column"] == "OrderDate"
+
+
+class TestNormalizeColumnNames:
+ """Test the main normalize_column_names method."""
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_normalize_xy_chart_config(
+ self, mock_get_context, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test full normalization of XY chart config."""
+ mock_get_context.return_value = mock_dataset_context
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="orderdate"), # lowercase - should normalize to
OrderDate
+ y=[
+ ColumnRef(name="sales", aggregate="SUM")
+ ], # lowercase - should normalize to Sales
+ kind="line",
+ filters=[FilterConfig(column="productline", op="=", value="Classic
Cars")],
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+ assert normalized.x.name == "OrderDate"
+ assert normalized.y[0].name == "Sales"
+ assert normalized.filters is not None
+ assert normalized.filters[0].column == "ProductLine"
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_normalize_table_chart_config(
+ self, mock_get_context, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test full normalization of table chart config."""
+ mock_get_context.return_value = mock_dataset_context
+
+ config = TableChartConfig(
+ chart_type="table",
+ columns=[
+ ColumnRef(name="orderdate"),
+ ColumnRef(name="productline"),
+ ColumnRef(name="sales", aggregate="SUM"),
+ ],
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+ assert normalized.columns[0].name == "OrderDate"
+ assert normalized.columns[1].name == "ProductLine"
+ assert normalized.columns[2].name == "Sales"
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_returns_original_when_dataset_not_found(self, mock_get_context) -> None:
+ """Test that original config is returned when dataset context is
unavailable."""
+ mock_get_context.return_value = None
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="orderdate"),
+ y=[ColumnRef(name="sales", aggregate="SUM")],
+ kind="line",
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=999)
+
+ # Should return original config unchanged
+ assert normalized.x.name == "orderdate"
+ assert normalized.y[0].name == "sales"
+
+
+class TestTimeSeriesFilterPromptFix:
+ """Test the fix for time series charts incorrectly prompting x-axis
filters."""
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_x_axis_matches_existing_filter_after_normalization(
+ self, mock_get_context, mock_dataset_context: DatasetContext
+ ) -> None:
+ """
+ Test the core fix: when creating a time series chart with
+ 'order_date' as x-axis, and there's already a filter with
+ 'OrderDate', after normalization they should match.
+
+ This is the exact scenario from the bug report where:
+ - User creates chart with x_axis = 'order_date'
+ - Dataset has column named 'OrderDate'
+ - Existing filter has subject = 'OrderDate'
+ - Without normalization: 'order_date' != 'OrderDate' -> prompt shown
+ - With normalization: 'OrderDate' == 'OrderDate' -> no prompt
+ """
+ mock_get_context.return_value = mock_dataset_context
+
+ # Simulate what the MCP service receives from user
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="orderdate"), # User provides lowercase
+ y=[ColumnRef(name="sales", aggregate="SUM")],
+ kind="line",
+ # Simulating an existing filter with the canonical name
+ filters=[
+ FilterConfig(column="OrderDate", op=">", value="2023-01-01"),
+ ],
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+ # After normalization, x.name should match the filter column exactly
+ assert normalized.x.name == "OrderDate"
+ assert normalized.filters is not None
+ assert normalized.filters[0].column == "OrderDate"
+
+ # This equality is what the frontend checks - now they match!
+ assert normalized.x.name == normalized.filters[0].column
+
+
+@pytest.fixture
+def uppercase_dataset_context() -> DatasetContext:
+ """Create a mock dataset context with all-uppercase column names (like
flights)."""
+ return DatasetContext(
+ id=24,
+ table_name="flights",
+ schema="public",
+ database_name="examples",
+ available_columns=[
+ {"name": "DEPARTURE_DELAY", "type": "FLOAT", "is_numeric": True},
+ {"name": "ARRIVAL_DELAY", "type": "FLOAT", "is_numeric": True},
+ {"name": "DISTANCE", "type": "BIGINT", "is_numeric": True},
+ {"name": "AIRLINE", "type": "VARCHAR", "is_temporal": False},
+ {"name": "ds", "type": "TIMESTAMP", "is_temporal": True},
+ ],
+ available_metrics=[
+ {"name": "count", "expression": "COUNT(*)", "description": None},
+ ],
+ )
+
+
+class TestNormalizeMultipleYAxisColumns:
+ """Test normalization of multiple y-axis columns."""
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_normalize_multiple_y_columns(
+ self, mock_get_context, uppercase_dataset_context: DatasetContext
+ ) -> None:
+ """Test that all y-axis columns are normalized."""
+ mock_get_context.return_value = uppercase_dataset_context
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="ds"),
+ y=[
+ ColumnRef(name="departure_delay", aggregate="AVG"),
+ ColumnRef(name="arrival_delay", aggregate="AVG"),
+ ],
+ kind="area",
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+ assert normalized.y[0].name == "DEPARTURE_DELAY"
+ assert normalized.y[1].name == "ARRIVAL_DELAY"
+
+
+class TestNormalizeUppercaseDataset:
+ """Test normalization against dataset with all-uppercase column names."""
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_lowercase_to_uppercase(
+ self, mock_get_context, uppercase_dataset_context: DatasetContext
+ ) -> None:
+ """Test lowercase input normalizes to uppercase canonical names."""
+ mock_get_context.return_value = uppercase_dataset_context
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="ds"),
+ y=[ColumnRef(name="distance", aggregate="AVG")],
+ kind="bar",
+ group_by=ColumnRef(name="airline"),
+ filters=[FilterConfig(column="airline", op="=", value="AA")],
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+ assert normalized.x.name == "ds"
+ assert normalized.y[0].name == "DISTANCE"
+ assert normalized.group_by is not None
+ assert normalized.group_by.name == "AIRLINE"
+ assert normalized.filters is not None
+ assert normalized.filters[0].column == "AIRLINE"
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_exact_match_preserved(
+ self, mock_get_context, uppercase_dataset_context: DatasetContext
+ ) -> None:
+ """Test that already-correct names are preserved unchanged."""
+ mock_get_context.return_value = uppercase_dataset_context
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="ds"),
+ y=[ColumnRef(name="DEPARTURE_DELAY", aggregate="AVG")],
+ kind="line",
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+ assert normalized.x.name == "ds"
+ assert normalized.y[0].name == "DEPARTURE_DELAY"
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_metric_normalized_in_y_axis(
+ self, mock_get_context, uppercase_dataset_context: DatasetContext
+ ) -> None:
+ """Test that metric names used in y-axis are normalized."""
+ mock_get_context.return_value = uppercase_dataset_context
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="ds"),
+ y=[ColumnRef(name="COUNT", aggregate="SUM")],
+ kind="bar",
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+ # 'COUNT' should normalize to 'count' (the metric name)
+ assert normalized.y[0].name == "count"
+
+
+class TestNormalizeEdgeCases:
+ """Test edge cases for column name normalization."""
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_config_with_no_filters(
+ self, mock_get_context, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test normalization when config has no filters."""
+ mock_get_context.return_value = mock_dataset_context
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="orderdate"),
+ y=[ColumnRef(name="sales", aggregate="SUM")],
+ kind="line",
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+ assert normalized.x.name == "OrderDate"
+ assert normalized.y[0].name == "Sales"
+ assert normalized.filters is None
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_config_with_empty_filters(
+ self, mock_get_context, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test normalization when config has empty filters list."""
+ mock_get_context.return_value = mock_dataset_context
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="orderdate"),
+ y=[ColumnRef(name="sales", aggregate="SUM")],
+ kind="line",
+ filters=[],
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+ assert normalized.x.name == "OrderDate"
+ assert normalized.filters is not None
+ assert len(normalized.filters) == 0
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_config_with_no_group_by(
+ self, mock_get_context, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test normalization when config has no group_by."""
+ mock_get_context.return_value = mock_dataset_context
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="orderdate"),
+ y=[ColumnRef(name="sales", aggregate="SUM")],
+ kind="bar",
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+ assert normalized.x.name == "OrderDate"
+ assert normalized.group_by is None
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_all_fields_normalized_together(
+ self, mock_get_context, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test that x, y, group_by, and filters are all normalized in one
call."""
+ mock_get_context.return_value = mock_dataset_context
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="ORDERDATE"),
+ y=[
+ ColumnRef(name="sales", aggregate="SUM"),
+ ColumnRef(name="QUANTITY_ORDERED", aggregate="COUNT"),
+ ],
+ kind="bar",
+ group_by=ColumnRef(name="PRODUCTLINE"),
+ filters=[
+ FilterConfig(column="productline", op="=", value="Classic
Cars"),
+ FilterConfig(column="ORDERDATE", op=">", value="2023-01-01"),
+ ],
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+ assert normalized.x.name == "OrderDate"
+ assert normalized.y[0].name == "Sales"
+ assert normalized.y[1].name == "quantity_ordered"
+ assert normalized.group_by is not None
+ assert normalized.group_by.name == "ProductLine"
+ assert normalized.filters is not None
+ assert normalized.filters[0].column == "ProductLine"
+ assert normalized.filters[1].column == "OrderDate"
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_normalization_is_idempotent(
+ self, mock_get_context, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test that normalizing already-normalized config returns same
result."""
+ mock_get_context.return_value = mock_dataset_context
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="orderdate"),
+ y=[ColumnRef(name="sales", aggregate="SUM")],
+ kind="line",
+ filters=[FilterConfig(column="productline", op="=", value="Cars")],
+ )
+
+ first = DatasetValidator.normalize_column_names(config, dataset_id=18)
+ second = DatasetValidator.normalize_column_names(first, dataset_id=18)
+
+ assert first.x.name == second.x.name == "OrderDate"
+ assert first.y[0].name == second.y[0].name == "Sales"
+ assert first.filters is not None
+ assert second.filters is not None
+ assert first.filters[0].column == second.filters[0].column == "ProductLine"
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_aggregate_preserved_after_normalization(
+ self, mock_get_context, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test that aggregate functions are preserved during normalization."""
+ mock_get_context.return_value = mock_dataset_context
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="orderdate"),
+ y=[
+ ColumnRef(name="sales", aggregate="SUM"),
+ ColumnRef(name="QUANTITY_ORDERED", aggregate="AVG"),
+ ],
+ kind="bar",
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+ assert normalized.y[0].aggregate == "SUM"
+ assert normalized.y[1].aggregate == "AVG"
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_filter_operator_and_value_preserved(
+ self, mock_get_context, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Test that filter op and value are preserved during normalization."""
+ mock_get_context.return_value = mock_dataset_context
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="orderdate"),
+ y=[ColumnRef(name="sales", aggregate="SUM")],
+ kind="line",
+ filters=[
+ FilterConfig(column="ORDERDATE", op=">=", value="2023-01-01"),
+ FilterConfig(column="sales", op=">", value=1000),
+ ],
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+ assert normalized.filters is not None
+ assert normalized.filters[0].column == "OrderDate"
+ assert normalized.filters[0].op == ">="
+ assert normalized.filters[0].value == "2023-01-01"
+ assert normalized.filters[1].column == "Sales"
+ assert normalized.filters[1].op == ">"
+ assert normalized.filters[1].value == 1000
+
+
+class TestNormalizeXAxisFilterConsistency:
+ """Test that x-axis and filter column names are consistent after
normalization.
+
+ These tests verify the core bug fix: when x-axis and filter reference
+ the same column but with different cases, normalization ensures they match.
+ """
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_both_wrong_case_normalized_to_same(
+ self, mock_get_context, mock_dataset_context: DatasetContext
+ ) -> None:
+ """Both x-axis and filter in wrong case normalize to same canonical
name."""
+ mock_get_context.return_value = mock_dataset_context
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="ORDERDATE"),
+ y=[ColumnRef(name="sales", aggregate="SUM")],
+ kind="line",
+ filters=[FilterConfig(column="orderdate", op=">",
value="2023-01-01")],
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+ assert normalized.filters is not None
+ assert normalized.x.name == normalized.filters[0].column == "OrderDate"
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_uppercase_dataset_x_filter_match(
+ self, mock_get_context, uppercase_dataset_context: DatasetContext
+ ) -> None:
+ """On uppercase-column dataset, both lowercase refs normalize to
uppercase."""
+ mock_get_context.return_value = uppercase_dataset_context
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="ds"),
+ y=[ColumnRef(name="departure_delay", aggregate="AVG")],
+ kind="line",
+ filters=[FilterConfig(column="ds", op=">", value="2015-01-01")],
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+ assert normalized.filters is not None
+ assert normalized.x.name == normalized.filters[0].column == "ds"
+
+ @patch.object(DatasetValidator, "_get_dataset_context")
+ def test_group_by_matches_filter_after_normalization(
+ self, mock_get_context, uppercase_dataset_context: DatasetContext
+ ) -> None:
+ """group_by and filter for same column normalize to same canonical
name."""
+ mock_get_context.return_value = uppercase_dataset_context
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="ds"),
+ y=[ColumnRef(name="distance", aggregate="AVG")],
+ kind="bar",
+ group_by=ColumnRef(name="Airline"),
+ filters=[FilterConfig(column="airline", op="=", value="AA")],
+ )
+
+ normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+ assert normalized.group_by is not None
+ assert normalized.filters is not None
+ assert normalized.group_by.name == normalized.filters[0].column == "AIRLINE"
diff --git a/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py b/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py
index af08834d57b..0a8771e48ba 100644
--- a/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py
+++ b/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py
@@ -35,6 +35,7 @@ from superset.mcp_service.chart.schemas import (
TableChartConfig,
XYChartConfig,
)
+from superset.mcp_service.common.error_schemas import DatasetContext
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@@ -706,3 +707,151 @@ class TestGenerateExploreLink:
assert result.data["form_data"].get("x_axis") == "date"
# Verify datasource field format: "{dataset_id}__table"
assert result.data["form_data"].get("datasource") == "1__table"
+
+
+class TestGenerateExploreLinkColumnNormalization:
+ """Tests that generate_explore_link normalizes column names.
+
+ This verifies the fix where user-provided column names in wrong case
+ (e.g., 'order_date') are normalized to the canonical dataset name
+ (e.g., 'OrderDate') before being used in form_data.
+ """
+
+ @patch(
+ "superset.mcp_service.chart.validation.dataset_validator.DatasetValidator._get_dataset_context"
+ )
+ @patch("superset.daos.dataset.DatasetDAO.find_by_id")
+ @patch(
+ "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
+ )
+ @pytest.mark.asyncio
+ async def test_xy_chart_x_axis_normalized_in_form_data(
+ self,
+ mock_create_form_data,
+ mock_find_dataset,
+ mock_get_context,
+ mcp_server,
+ ):
+ """x-axis column name in wrong case is normalized in form_data."""
+ mock_create_form_data.return_value = "norm_test_key_1"
+ mock_find_dataset.return_value = _mock_dataset(id=18)
+ mock_get_context.return_value = DatasetContext(
+ id=18,
+ table_name="Vehicle Sales",
+ schema="public",
+ database_name="examples",
+ available_columns=[
+ {"name": "OrderDate", "type": "DATE", "is_temporal": True},
+ {"name": "Sales", "type": "FLOAT", "is_numeric": True},
+ ],
+ available_metrics=[],
+ )
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="orderdate"),
+ y=[ColumnRef(name="sales", aggregate="SUM")],
+ kind="line",
+ )
+ request = GenerateExploreLinkRequest(dataset_id="18", config=config)
+
+ async with Client(mcp_server) as client:
+ result = await client.call_tool(
+ "generate_explore_link", {"request": request.model_dump()}
+ )
+
+ assert result.data["error"] is None
+ # x-axis should be normalized from 'orderdate' to 'OrderDate'
+ assert result.data["form_data"]["x_axis"] == "OrderDate"
+
+ @patch(
+ "superset.mcp_service.chart.validation.dataset_validator.DatasetValidator._get_dataset_context"
+ )
+ @patch("superset.daos.dataset.DatasetDAO.find_by_id")
+ @patch(
+ "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
+ )
+ @pytest.mark.asyncio
+ async def test_filter_column_normalized_in_form_data(
+ self,
+ mock_create_form_data,
+ mock_find_dataset,
+ mock_get_context,
+ mcp_server,
+ ):
+ """Filter column name in wrong case is normalized in adhoc_filters."""
+ mock_create_form_data.return_value = "norm_test_key_2"
+ mock_find_dataset.return_value = _mock_dataset(id=18)
+ mock_get_context.return_value = DatasetContext(
+ id=18,
+ table_name="Vehicle Sales",
+ schema="public",
+ database_name="examples",
+ available_columns=[
+ {"name": "OrderDate", "type": "DATE", "is_temporal": True},
+ {"name": "Sales", "type": "FLOAT", "is_numeric": True},
+ ],
+ available_metrics=[],
+ )
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="orderdate"),
+ y=[ColumnRef(name="sales", aggregate="SUM")],
+ kind="line",
+ filters=[
+ FilterConfig(column="orderdate", op=">", value="2023-01-01"),
+ ],
+ )
+ request = GenerateExploreLinkRequest(dataset_id="18", config=config)
+
+ async with Client(mcp_server) as client:
+ result = await client.call_tool(
+ "generate_explore_link", {"request": request.model_dump()}
+ )
+
+ assert result.data["error"] is None
+ form_data = result.data["form_data"]
+ # x-axis normalized
+ assert form_data["x_axis"] == "OrderDate"
+ # filter subject normalized to match x-axis
+ adhoc_filters = form_data.get("adhoc_filters", [])
+ assert len(adhoc_filters) == 1
+ assert adhoc_filters[0]["subject"] == "OrderDate"
+
+ @patch(
+ "superset.mcp_service.chart.validation.dataset_validator.DatasetValidator._get_dataset_context"
+ )
+ @patch("superset.daos.dataset.DatasetDAO.find_by_id")
+ @patch(
+ "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
+ )
+ @pytest.mark.asyncio
+ async def test_normalization_fallback_when_dataset_not_found(
+ self,
+ mock_create_form_data,
+ mock_find_dataset,
+ mock_get_context,
+ mcp_server,
+ ):
+ """When dataset context is unavailable, original names pass through."""
+ mock_create_form_data.return_value = "norm_test_key_3"
+ mock_find_dataset.return_value = _mock_dataset(id=99)
+ mock_get_context.return_value = None
+
+ config = XYChartConfig(
+ chart_type="xy",
+ x=ColumnRef(name="orderdate"),
+ y=[ColumnRef(name="sales", aggregate="SUM")],
+ kind="line",
+ )
+ request = GenerateExploreLinkRequest(dataset_id="99", config=config)
+
+ async with Client(mcp_server) as client:
+ result = await client.call_tool(
+ "generate_explore_link", {"request": request.model_dump()}
+ )
+
+ assert result.data["error"] is None
+ # original names should pass through unchanged
+ assert result.data["form_data"]["x_axis"] == "orderdate"