This is an automated email from the ASF dual-hosted git repository.

kgabryje pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new a1312a86e88 fix(mcp): normalize column names to fix time series filter prompt issue (#37187)
a1312a86e88 is described below

commit a1312a86e88fb7f9cf6d9bff1e3627057f33abbd
Author: Amin Ghadersohi <[email protected]>
AuthorDate: Wed Feb 25 09:27:53 2026 -0500

    fix(mcp): normalize column names to fix time series filter prompt issue (#37187)
    
    Co-authored-by: Claude Opus 4.5 <[email protected]>
---
 .../chart/validation/dataset_validator.py          | 153 ++++-
 superset/mcp_service/chart/validation/pipeline.py  |  86 ++-
 .../explore/tool/generate_explore_link.py          |  15 +-
 .../validation/test_column_name_normalization.py   | 681 +++++++++++++++++++++
 .../explore/tool/test_generate_explore_link.py     | 149 +++++
 5 files changed, 1073 insertions(+), 11 deletions(-)
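
The core of the fix is a case-insensitive lookup that maps a user-supplied column name to the spelling stored in the dataset, so the frontend's case-sensitive comparisons succeed. A minimal standalone sketch of that idea (the plain dicts below mirror the available_columns/available_metrics shapes used in the diff; nothing here imports from Superset):

    def canonical_name(name, columns, metrics):
        """Return the dataset's stored spelling for name, case-insensitively."""
        lowered = name.lower()
        # Columns are checked before metrics, matching the diff's ordering.
        for entry in list(columns) + list(metrics):
            if entry["name"].lower() == lowered:
                return entry["name"]
        return name  # unknown names pass through; validation reports them

    columns = [{"name": "OrderDate"}, {"name": "Sales"}]
    metrics = [{"name": "TotalRevenue"}]
    assert canonical_name("orderdate", columns, metrics) == "OrderDate"
    assert canonical_name("TOTALREVENUE", columns, metrics) == "TotalRevenue"
    assert canonical_name("unknown", columns, metrics) == "unknown"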

diff --git a/superset/mcp_service/chart/validation/dataset_validator.py b/superset/mcp_service/chart/validation/dataset_validator.py
index b03d0ffe9c0..c50fccc6938 100644
--- a/superset/mcp_service/chart/validation/dataset_validator.py
+++ b/superset/mcp_service/chart/validation/dataset_validator.py
@@ -22,7 +22,7 @@ Validates that referenced columns exist in the dataset schema.
 
 import difflib
 import logging
-from typing import Dict, List, Tuple
+from typing import Any, Dict, List, Tuple
 
 from superset.mcp_service.chart.schemas import (
     ColumnRef,
@@ -37,13 +37,25 @@ from superset.mcp_service.common.error_schemas import (
 
 logger = logging.getLogger(__name__)
 
+# Exceptions that can occur during column name normalization.
+# Shared by the validation pipeline and tool-level normalization calls.
+NORMALIZATION_EXCEPTIONS = (
+    ImportError,
+    AttributeError,
+    KeyError,
+    ValueError,
+    TypeError,
+)
+
 
 class DatasetValidator:
     """Validates chart configuration against dataset schema."""
 
     @staticmethod
     def validate_against_dataset(
-        config: TableChartConfig | XYChartConfig, dataset_id: int | str
+        config: TableChartConfig | XYChartConfig,
+        dataset_id: int | str,
+        dataset_context: DatasetContext | None = None,
     ) -> Tuple[bool, ChartGenerationError | None]:
         """
         Validate chart configuration against dataset schema.
@@ -51,12 +63,15 @@ class DatasetValidator:
         Args:
             config: Chart configuration to validate
             dataset_id: Dataset ID to validate against
+            dataset_context: Pre-fetched dataset context to avoid duplicate
+                DB queries. If None, fetches from the database.
 
         Returns:
             Tuple of (is_valid, error)
         """
-        # Get dataset context
-        dataset_context = DatasetValidator._get_dataset_context(dataset_id)
+        # Get dataset context (reuse if provided)
+        if dataset_context is None:
+            dataset_context = DatasetValidator._get_dataset_context(dataset_id)
         if not dataset_context:
             from superset.mcp_service.utils.error_builder import (
                 ChartErrorBuilder,
@@ -198,6 +213,136 @@ class DatasetValidator:
 
         return False
 
+    @staticmethod
+    def _get_canonical_column_name(
+        column_name: str, dataset_context: DatasetContext
+    ) -> str:
+        """
+        Get the canonical column name from the dataset.
+
+        Performs case-insensitive matching and returns the actual column name
+        as stored in the dataset. This ensures column names in form_data match
+        exactly with what the frontend expects.
+
+        Args:
+            column_name: The column name to normalize
+            dataset_context: Dataset context with column information
+
+        Returns:
+            The canonical column name from the dataset, or the original name
+            if no match is found.
+        """
+        column_lower = column_name.lower()
+
+        # Check regular columns first
+        for col in dataset_context.available_columns:
+            if col["name"].lower() == column_lower:
+                return col["name"]
+
+        # Check metrics
+        for metric in dataset_context.available_metrics:
+            if metric["name"].lower() == column_lower:
+                return metric["name"]
+
+        # Return original if not found (validation should catch this case)
+        return column_name
+
+    @staticmethod
+    def _normalize_xy_config(
+        config_dict: Dict[str, Any], dataset_context: DatasetContext
+    ) -> None:
+        """Normalize column names in an XY chart config dict in place."""
+        # Normalize x-axis column
+        if "x" in config_dict and config_dict["x"]:
+            config_dict["x"]["name"] = 
DatasetValidator._get_canonical_column_name(
+                config_dict["x"]["name"], dataset_context
+            )
+
+        # Normalize y-axis columns
+        if "y" in config_dict and config_dict["y"]:
+            for y_col in config_dict["y"]:
+                y_col["name"] = DatasetValidator._get_canonical_column_name(
+                    y_col["name"], dataset_context
+                )
+
+        # Normalize group_by column
+        if "group_by" in config_dict and config_dict["group_by"]:
+            config_dict["group_by"]["name"] = (
+                DatasetValidator._get_canonical_column_name(
+                    config_dict["group_by"]["name"], dataset_context
+                )
+            )
+
+    @staticmethod
+    def _normalize_table_config(
+        config_dict: Dict[str, Any], dataset_context: DatasetContext
+    ) -> None:
+        """Normalize column names in a table chart config dict in place."""
+        if "columns" in config_dict and config_dict["columns"]:
+            for col in config_dict["columns"]:
+                col["name"] = DatasetValidator._get_canonical_column_name(
+                    col["name"], dataset_context
+                )
+
+    @staticmethod
+    def _normalize_filters(
+        config_dict: Dict[str, Any], dataset_context: DatasetContext
+    ) -> None:
+        """Normalize filter column names in a config dict in place."""
+        if "filters" in config_dict and config_dict["filters"]:
+            for filter_config in config_dict["filters"]:
+                if filter_config and "column" in filter_config:
+                    filter_config["column"] = (
+                        DatasetValidator._get_canonical_column_name(
+                            filter_config["column"], dataset_context
+                        )
+                    )
+
+    @staticmethod
+    def normalize_column_names(
+        config: TableChartConfig | XYChartConfig,
+        dataset_id: int | str,
+        dataset_context: DatasetContext | None = None,
+    ) -> TableChartConfig | XYChartConfig:
+        """
+        Normalize column names in config to match the canonical dataset column names.
+
+        This fixes case sensitivity issues where user-provided column names
+        (e.g., 'order_date') don't match exactly with the dataset column names
+        (e.g., 'OrderDate'). The frontend performs case-sensitive comparisons,
+        so we need to ensure column names match exactly.
+
+        Args:
+            config: Chart configuration with column references
+            dataset_id: Dataset ID to get canonical column names from
+            dataset_context: Pre-fetched dataset context to avoid duplicate
+                DB queries. If None, fetches from the database.
+
+        Returns:
+            A new config with normalized column names
+        """
+        if dataset_context is None:
+            dataset_context = DatasetValidator._get_dataset_context(dataset_id)
+        if not dataset_context:
+            return config
+
+        # Create a mutable copy of the config
+        config_dict = config.model_dump()
+
+        # Normalize based on config type
+        if isinstance(config, XYChartConfig):
+            DatasetValidator._normalize_xy_config(config_dict, dataset_context)
+        elif isinstance(config, TableChartConfig):
+            DatasetValidator._normalize_table_config(config_dict, dataset_context)
+
+        # Normalize filter columns (common to both config types)
+        DatasetValidator._normalize_filters(config_dict, dataset_context)
+
+        # Reconstruct the config with normalized names
+        if isinstance(config, XYChartConfig):
+            return XYChartConfig.model_validate(config_dict)
+        return TableChartConfig.model_validate(config_dict)
+
     @staticmethod
     def _get_column_suggestions(
         column_name: str, dataset_context: DatasetContext, max_suggestions: int = 3
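
Worth noting in the validator above: normalization never mutates the pydantic config object directly. It dumps the model to a plain dict, rewrites names in place, and re-validates into a fresh instance. A small sketch of that round-trip pattern (using a generic pydantic v2 model, not the Superset schemas):

    from pydantic import BaseModel

    class Point(BaseModel):
        name: str

    def normalize(point: Point) -> Point:
        """Dump -> mutate the dict -> re-validate into a new instance."""
        data = point.model_dump()
        data["name"] = data["name"].title()  # stand-in for the canonical lookup
        return Point.model_validate(data)

    p = normalize(Point(name="orderdate"))
    assert p.name == "Orderdate"  # the original Point is left untouched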
diff --git a/superset/mcp_service/chart/validation/pipeline.py b/superset/mcp_service/chart/validation/pipeline.py
index 948f9d2e62d..a0f475ffa0e 100644
--- a/superset/mcp_service/chart/validation/pipeline.py
+++ b/superset/mcp_service/chart/validation/pipeline.py
@@ -27,7 +27,10 @@ from superset.mcp_service.chart.schemas import (
     ChartConfig,
     GenerateChartRequest,
 )
-from superset.mcp_service.common.error_schemas import ChartGenerationError
+from superset.mcp_service.common.error_schemas import (
+    ChartGenerationError,
+    DatasetContext,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -168,9 +171,14 @@ class ValidationPipeline:
             if request is None:
                 return ValidationResult(is_valid=False, error=error)
 
-            # Layer 2: Dataset validation
+            # Fetch dataset context once and reuse across validation layers
+            dataset_context = ValidationPipeline._get_dataset_context(
+                request.dataset_id
+            )
+
+            # Layer 2: Dataset validation (reuses context)
             is_valid, error = ValidationPipeline._validate_dataset(
-                request.config, request.dataset_id
+                request.config, request.dataset_id, dataset_context
             )
             if not is_valid:
                 return ValidationResult(is_valid=False, request=request, error=error)
@@ -181,8 +189,15 @@ class ValidationPipeline:
             )
             # Runtime validation always returns True now, warnings are informational
 
+            # Layer 4: Column name normalization (reuses context)
+            normalized_request = ValidationPipeline._normalize_column_names(
+                request, dataset_context
+            )
+
             return ValidationResult(
-                is_valid=True, request=request, warnings=warnings_metadata
+                is_valid=True,
+                request=normalized_request,
+                warnings=warnings_metadata,
             )
 
         except Exception as e:
@@ -201,15 +216,32 @@ class ValidationPipeline:
             )
             return ValidationResult(is_valid=False, error=error)
 
+    @staticmethod
+    def _get_dataset_context(
+        dataset_id: int | str,
+    ) -> DatasetContext | None:
+        """Fetch dataset context once to reuse across validation layers."""
+        try:
+            from .dataset_validator import DatasetValidator
+
+            return DatasetValidator._get_dataset_context(dataset_id)
+        except ImportError:
+            logger.warning("Dataset validator not available, skipping context 
fetch")
+            return None
+
     @staticmethod
     def _validate_dataset(
-        config: ChartConfig, dataset_id: int | str
+        config: ChartConfig,
+        dataset_id: int | str,
+        dataset_context: DatasetContext | None = None,
     ) -> Tuple[bool, ChartGenerationError | None]:
         """Validate configuration against dataset schema."""
         try:
             from .dataset_validator import DatasetValidator
 
-            return DatasetValidator.validate_against_dataset(config, dataset_id)
+            return DatasetValidator.validate_against_dataset(
+                config, dataset_id, dataset_context=dataset_context
+            )
         except ImportError:
             # Skip if dataset validator not available
             logger.warning(
@@ -248,6 +280,48 @@ class ValidationPipeline:
             # Don't fail on runtime validation errors
             return True, None
 
+    @staticmethod
+    def _normalize_column_names(
+        request: GenerateChartRequest,
+        dataset_context: DatasetContext | None = None,
+    ) -> GenerateChartRequest:
+        """
+        Normalize column names in the request to match canonical dataset names.
+
+        This fixes case sensitivity issues where user-provided column names
+        don't match exactly with the dataset column names. For example,
+        if a user provides 'order_date' but the dataset has 'OrderDate',
+        this method will normalize it to 'OrderDate'.
+
+        Args:
+            request: The validated chart generation request
+            dataset_context: Pre-fetched dataset context to avoid duplicate
+                DB queries. If None, fetches from the database.
+
+        Returns:
+            A new request with normalized column names
+        """
+        try:
+            from .dataset_validator import DatasetValidator
+
+            normalized_config = DatasetValidator.normalize_column_names(
+                request.config,
+                request.dataset_id,
+                dataset_context=dataset_context,
+            )
+
+            # Create a new request with the normalized config
+            request_dict = request.model_dump()
+            request_dict["config"] = normalized_config.model_dump()
+
+            return GenerateChartRequest.model_validate(request_dict)
+
+        except (ImportError, AttributeError, KeyError, ValueError, TypeError) as e:
+            # If normalization fails, return the original request
+            # Validation has already passed, so this is a non-critical failure
+            logger.warning("Column name normalization failed: %s", e)
+            return request
+
     @staticmethod
     def validate_filters(
         filters: List[Any],
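
The pipeline change above follows a fetch-once pattern: the dataset context is loaded a single time and threaded through both dataset validation (layer 2) and column name normalization (layer 4), and a normalization failure falls back to the already-validated request instead of failing the pipeline. A hedged sketch of that control flow (the function names are illustrative stand-ins, not the Superset API):

    def run_pipeline(request, get_context, validate, normalize):
        """Fetch the dataset context once; normalization errors are non-fatal."""
        context = get_context(request["dataset_id"])   # one DB round-trip
        if not validate(request["config"], context):   # layer 2 reuses context
            return {"is_valid": False}
        try:
            request = normalize(request, context)      # layer 4 reuses it too
        except (ImportError, AttributeError, KeyError, ValueError, TypeError):
            pass  # keep the already-validated request unchanged
        return {"is_valid": True, "request": request}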
diff --git a/superset/mcp_service/explore/tool/generate_explore_link.py b/superset/mcp_service/explore/tool/generate_explore_link.py
index 3048a538c6c..dca721caa8b 100644
--- a/superset/mcp_service/explore/tool/generate_explore_link.py
+++ b/superset/mcp_service/explore/tool/generate_explore_link.py
@@ -93,9 +93,22 @@ async def generate_explore_link(
     try:
         await ctx.report_progress(1, 3, "Converting configuration to form data")
         with event_logger.log_context(action="mcp.generate_explore_link.form_data"):
+            # Normalize column names to match canonical dataset column names
+            # This fixes case sensitivity issues (e.g., 'order_date' vs 'OrderDate')
+            try:
+                from superset.mcp_service.chart.validation.dataset_validator import (
+                    DatasetValidator,
+                )
+
+                normalized_config = DatasetValidator.normalize_column_names(
+                    request.config, request.dataset_id
+                )
+            except (ImportError, AttributeError, KeyError, ValueError, TypeError):
+                normalized_config = request.config
+
             # Map config to form_data using shared utilities
             form_data = map_config_to_form_data(
-                request.config, dataset_id=request.dataset_id
+                normalized_config, dataset_id=request.dataset_id
             )
 
         # Add datasource to form_data for consistency with generate_chart
diff --git a/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py
new file mode 100644
index 00000000000..77fdf64143f
--- /dev/null
+++ b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py
@@ -0,0 +1,681 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests for column name normalization in the MCP service.
+
+This tests the fix for the issue where time series charts would incorrectly
+prompt to add the x-axis to filters when the column name case didn't match
+exactly (e.g., 'order_date' vs 'OrderDate').
+"""
+
+from typing import Any, Dict
+from unittest.mock import patch
+
+import pytest
+
+from superset.mcp_service.chart.schemas import (
+    ColumnRef,
+    FilterConfig,
+    TableChartConfig,
+    XYChartConfig,
+)
+from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
+from superset.mcp_service.common.error_schemas import DatasetContext
+
+
+@pytest.fixture
+def mock_dataset_context() -> DatasetContext:
+    """Create a mock dataset context with mixed-case column names."""
+    return DatasetContext(
+        id=18,
+        table_name="Vehicle Sales",
+        schema="public",
+        database_name="examples",
+        available_columns=[
+            {"name": "OrderDate", "type": "DATE", "is_temporal": True},
+            {"name": "ProductLine", "type": "VARCHAR", "is_temporal": False},
+            {"name": "Sales", "type": "DECIMAL", "is_numeric": True},
+            {"name": "quantity_ordered", "type": "INTEGER", "is_numeric": 
True},
+        ],
+        available_metrics=[
+            {"name": "TotalRevenue", "expression": "SUM(Sales)", 
"description": None},
+        ],
+    )
+
+
+class TestGetCanonicalColumnName:
+    """Test _get_canonical_column_name static method."""
+
+    def test_exact_match_returns_same_name(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that exact match returns the same column name."""
+        result = DatasetValidator._get_canonical_column_name(
+            "OrderDate", mock_dataset_context
+        )
+        assert result == "OrderDate"
+
+    def test_lowercase_returns_canonical_name(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that lowercase input returns the canonical (dataset) column 
name."""
+        result = DatasetValidator._get_canonical_column_name(
+            "orderdate", mock_dataset_context
+        )
+        assert result == "OrderDate"
+
+    def test_snake_case_returns_canonical_name(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that snake_case input returns the canonical column name."""
+        # 'order_date' won't match 'OrderDate' directly, but would match if
+        # the dataset had 'order_date'. This test verifies case-insensitive matching.
+        result = DatasetValidator._get_canonical_column_name(
+            "productline", mock_dataset_context
+        )
+        assert result == "ProductLine"
+
+    def test_uppercase_returns_canonical_name(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that uppercase input returns the canonical column name."""
+        result = DatasetValidator._get_canonical_column_name(
+            "SALES", mock_dataset_context
+        )
+        assert result == "Sales"
+
+    def test_metric_name_normalization(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that metric names are also normalized."""
+        result = DatasetValidator._get_canonical_column_name(
+            "totalrevenue", mock_dataset_context
+        )
+        assert result == "TotalRevenue"
+
+    def test_unknown_column_returns_original(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that unknown columns return the original name."""
+        result = DatasetValidator._get_canonical_column_name(
+            "unknown_column", mock_dataset_context
+        )
+        assert result == "unknown_column"
+
+
+class TestNormalizeXYConfig:
+    """Test _normalize_xy_config static method."""
+
+    def test_normalize_x_axis_column(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that x-axis column name is normalized."""
+        config_dict: Dict[str, Any] = {
+            "chart_type": "xy",
+            "x": {"name": "orderdate"},
+            "y": [{"name": "Sales", "aggregate": "SUM"}],
+            "kind": "line",
+        }
+
+        DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
+
+        assert config_dict["x"]["name"] == "OrderDate"
+
+    def test_normalize_y_axis_columns(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that y-axis column names are normalized."""
+        config_dict: Dict[str, Any] = {
+            "chart_type": "xy",
+            "x": {"name": "OrderDate"},
+            "y": [
+                {"name": "sales", "aggregate": "SUM"},
+                {"name": "QUANTITY_ORDERED", "aggregate": "COUNT"},
+            ],
+            "kind": "bar",
+        }
+
+        DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
+
+        assert config_dict["y"][0]["name"] == "Sales"
+        assert config_dict["y"][1]["name"] == "quantity_ordered"
+
+    def test_normalize_group_by_column(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that group_by column name is normalized."""
+        config_dict: Dict[str, Any] = {
+            "chart_type": "xy",
+            "x": {"name": "OrderDate"},
+            "y": [{"name": "Sales", "aggregate": "SUM"}],
+            "kind": "line",
+            "group_by": {"name": "productline"},
+        }
+
+        DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
+
+        assert config_dict["group_by"]["name"] == "ProductLine"
+
+
+class TestNormalizeTableConfig:
+    """Test _normalize_table_config static method."""
+
+    def test_normalize_table_columns(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that table column names are normalized."""
+        config_dict: Dict[str, Any] = {
+            "chart_type": "table",
+            "columns": [
+                {"name": "orderdate"},
+                {"name": "PRODUCTLINE"},
+                {"name": "sales", "aggregate": "SUM"},
+            ],
+        }
+
+        DatasetValidator._normalize_table_config(config_dict, mock_dataset_context)
+
+        assert config_dict["columns"][0]["name"] == "OrderDate"
+        assert config_dict["columns"][1]["name"] == "ProductLine"
+        assert config_dict["columns"][2]["name"] == "Sales"
+
+
+class TestNormalizeFilters:
+    """Test _normalize_filters static method."""
+
+    def test_normalize_filter_columns(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that filter column names are normalized."""
+        config_dict: Dict[str, Any] = {
+            "filters": [
+                {"column": "productline", "op": "=", "value": "Classic Cars"},
+                {"column": "ORDERDATE", "op": ">", "value": "2023-01-01"},
+            ],
+        }
+
+        DatasetValidator._normalize_filters(config_dict, mock_dataset_context)
+
+        assert config_dict["filters"][0]["column"] == "ProductLine"
+        assert config_dict["filters"][1]["column"] == "OrderDate"
+
+
+class TestNormalizeColumnNames:
+    """Test the main normalize_column_names method."""
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_normalize_xy_chart_config(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test full normalization of XY chart config."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),  # lowercase - should normalize to 
OrderDate
+            y=[
+                ColumnRef(name="sales", aggregate="SUM")
+            ],  # lowercase - should normalize to Sales
+            kind="line",
+            filters=[FilterConfig(column="productline", op="=", value="Classic 
Cars")],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.x.name == "OrderDate"
+        assert normalized.y[0].name == "Sales"
+        assert normalized.filters is not None
+        assert normalized.filters[0].column == "ProductLine"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_normalize_table_chart_config(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test full normalization of table chart config."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = TableChartConfig(
+            chart_type="table",
+            columns=[
+                ColumnRef(name="orderdate"),
+                ColumnRef(name="productline"),
+                ColumnRef(name="sales", aggregate="SUM"),
+            ],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.columns[0].name == "OrderDate"
+        assert normalized.columns[1].name == "ProductLine"
+        assert normalized.columns[2].name == "Sales"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_returns_original_when_dataset_not_found(self, mock_get_context) -> None:
+        """Test that original config is returned when dataset context is 
unavailable."""
+        mock_get_context.return_value = None
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=999)
+
+        # Should return original config unchanged
+        assert normalized.x.name == "orderdate"
+        assert normalized.y[0].name == "sales"
+
+
+class TestTimeSeriesFilterPromptFix:
+    """Test the fix for time series charts incorrectly prompting x-axis 
filters."""
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_x_axis_matches_existing_filter_after_normalization(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """
+        Test the core fix: when creating a time series chart with
+        'order_date' as x-axis, and there's already a filter with
+        'OrderDate', after normalization they should match.
+
+        This is the exact scenario from the bug report where:
+        - User creates chart with x_axis = 'order_date'
+        - Dataset has column named 'OrderDate'
+        - Existing filter has subject = 'OrderDate'
+        - Without normalization: 'order_date' != 'OrderDate' -> prompt shown
+        - With normalization: 'OrderDate' == 'OrderDate' -> no prompt
+        """
+        mock_get_context.return_value = mock_dataset_context
+
+        # Simulate what the MCP service receives from user
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),  # User provides lowercase
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+            # Simulating an existing filter with the canonical name
+            filters=[
+                FilterConfig(column="OrderDate", op=">", value="2023-01-01"),
+            ],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        # After normalization, x.name should match the filter column exactly
+        assert normalized.x.name == "OrderDate"
+        assert normalized.filters is not None
+        assert normalized.filters[0].column == "OrderDate"
+
+        # This equality is what the frontend checks - now they match!
+        assert normalized.x.name == normalized.filters[0].column
+
+
+@pytest.fixture
+def uppercase_dataset_context() -> DatasetContext:
+    """Create a mock dataset context with all-uppercase column names (like 
flights)."""
+    return DatasetContext(
+        id=24,
+        table_name="flights",
+        schema="public",
+        database_name="examples",
+        available_columns=[
+            {"name": "DEPARTURE_DELAY", "type": "FLOAT", "is_numeric": True},
+            {"name": "ARRIVAL_DELAY", "type": "FLOAT", "is_numeric": True},
+            {"name": "DISTANCE", "type": "BIGINT", "is_numeric": True},
+            {"name": "AIRLINE", "type": "VARCHAR", "is_temporal": False},
+            {"name": "ds", "type": "TIMESTAMP", "is_temporal": True},
+        ],
+        available_metrics=[
+            {"name": "count", "expression": "COUNT(*)", "description": None},
+        ],
+    )
+
+
+class TestNormalizeMultipleYAxisColumns:
+    """Test normalization of multiple y-axis columns."""
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_normalize_multiple_y_columns(
+        self, mock_get_context, uppercase_dataset_context: DatasetContext
+    ) -> None:
+        """Test that all y-axis columns are normalized."""
+        mock_get_context.return_value = uppercase_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ds"),
+            y=[
+                ColumnRef(name="departure_delay", aggregate="AVG"),
+                ColumnRef(name="arrival_delay", aggregate="AVG"),
+            ],
+            kind="area",
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+        assert normalized.y[0].name == "DEPARTURE_DELAY"
+        assert normalized.y[1].name == "ARRIVAL_DELAY"
+
+
+class TestNormalizeUppercaseDataset:
+    """Test normalization against dataset with all-uppercase column names."""
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_lowercase_to_uppercase(
+        self, mock_get_context, uppercase_dataset_context: DatasetContext
+    ) -> None:
+        """Test lowercase input normalizes to uppercase canonical names."""
+        mock_get_context.return_value = uppercase_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ds"),
+            y=[ColumnRef(name="distance", aggregate="AVG")],
+            kind="bar",
+            group_by=ColumnRef(name="airline"),
+            filters=[FilterConfig(column="airline", op="=", value="AA")],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+        assert normalized.x.name == "ds"
+        assert normalized.y[0].name == "DISTANCE"
+        assert normalized.group_by is not None
+        assert normalized.group_by.name == "AIRLINE"
+        assert normalized.filters is not None
+        assert normalized.filters[0].column == "AIRLINE"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_exact_match_preserved(
+        self, mock_get_context, uppercase_dataset_context: DatasetContext
+    ) -> None:
+        """Test that already-correct names are preserved unchanged."""
+        mock_get_context.return_value = uppercase_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ds"),
+            y=[ColumnRef(name="DEPARTURE_DELAY", aggregate="AVG")],
+            kind="line",
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+        assert normalized.x.name == "ds"
+        assert normalized.y[0].name == "DEPARTURE_DELAY"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_metric_normalized_in_y_axis(
+        self, mock_get_context, uppercase_dataset_context: DatasetContext
+    ) -> None:
+        """Test that metric names used in y-axis are normalized."""
+        mock_get_context.return_value = uppercase_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ds"),
+            y=[ColumnRef(name="COUNT", aggregate="SUM")],
+            kind="bar",
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+        # 'COUNT' should normalize to 'count' (the metric name)
+        assert normalized.y[0].name == "count"
+
+
+class TestNormalizeEdgeCases:
+    """Test edge cases for column name normalization."""
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_config_with_no_filters(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test normalization when config has no filters."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.x.name == "OrderDate"
+        assert normalized.y[0].name == "Sales"
+        assert normalized.filters is None
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_config_with_empty_filters(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test normalization when config has empty filters list."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+            filters=[],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.x.name == "OrderDate"
+        assert normalized.filters is not None
+        assert len(normalized.filters) == 0
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_config_with_no_group_by(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test normalization when config has no group_by."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="bar",
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.x.name == "OrderDate"
+        assert normalized.group_by is None
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_all_fields_normalized_together(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that x, y, group_by, and filters are all normalized in one 
call."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ORDERDATE"),
+            y=[
+                ColumnRef(name="sales", aggregate="SUM"),
+                ColumnRef(name="QUANTITY_ORDERED", aggregate="COUNT"),
+            ],
+            kind="bar",
+            group_by=ColumnRef(name="PRODUCTLINE"),
+            filters=[
+                FilterConfig(column="productline", op="=", value="Classic 
Cars"),
+                FilterConfig(column="ORDERDATE", op=">", value="2023-01-01"),
+            ],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.x.name == "OrderDate"
+        assert normalized.y[0].name == "Sales"
+        assert normalized.y[1].name == "quantity_ordered"
+        assert normalized.group_by is not None
+        assert normalized.group_by.name == "ProductLine"
+        assert normalized.filters is not None
+        assert normalized.filters[0].column == "ProductLine"
+        assert normalized.filters[1].column == "OrderDate"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_normalization_is_idempotent(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that normalizing already-normalized config returns same 
result."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+            filters=[FilterConfig(column="productline", op="=", value="Cars")],
+        )
+
+        first = DatasetValidator.normalize_column_names(config, dataset_id=18)
+        second = DatasetValidator.normalize_column_names(first, dataset_id=18)
+
+        assert first.x.name == second.x.name == "OrderDate"
+        assert first.y[0].name == second.y[0].name == "Sales"
+        assert first.filters is not None
+        assert second.filters is not None
+        assert first.filters[0].column == second.filters[0].column == "ProductLine"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_aggregate_preserved_after_normalization(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that aggregate functions are preserved during normalization."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[
+                ColumnRef(name="sales", aggregate="SUM"),
+                ColumnRef(name="QUANTITY_ORDERED", aggregate="AVG"),
+            ],
+            kind="bar",
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.y[0].aggregate == "SUM"
+        assert normalized.y[1].aggregate == "AVG"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_filter_operator_and_value_preserved(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that filter op and value are preserved during normalization."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+            filters=[
+                FilterConfig(column="ORDERDATE", op=">=", value="2023-01-01"),
+                FilterConfig(column="sales", op=">", value=1000),
+            ],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.filters is not None
+        assert normalized.filters[0].column == "OrderDate"
+        assert normalized.filters[0].op == ">="
+        assert normalized.filters[0].value == "2023-01-01"
+        assert normalized.filters[1].column == "Sales"
+        assert normalized.filters[1].op == ">"
+        assert normalized.filters[1].value == 1000
+
+
+class TestNormalizeXAxisFilterConsistency:
+    """Test that x-axis and filter column names are consistent after 
normalization.
+
+    These tests verify the core bug fix: when x-axis and filter reference
+    the same column but with different cases, normalization ensures they match.
+    """
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_both_wrong_case_normalized_to_same(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Both x-axis and filter in wrong case normalize to same canonical 
name."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ORDERDATE"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+            filters=[FilterConfig(column="orderdate", op=">", 
value="2023-01-01")],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.filters is not None
+        assert normalized.x.name == normalized.filters[0].column == "OrderDate"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_uppercase_dataset_x_filter_match(
+        self, mock_get_context, uppercase_dataset_context: DatasetContext
+    ) -> None:
+        """On uppercase-column dataset, both lowercase refs normalize to 
uppercase."""
+        mock_get_context.return_value = uppercase_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ds"),
+            y=[ColumnRef(name="departure_delay", aggregate="AVG")],
+            kind="line",
+            filters=[FilterConfig(column="ds", op=">", value="2015-01-01")],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+        assert normalized.filters is not None
+        assert normalized.x.name == normalized.filters[0].column == "ds"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_group_by_matches_filter_after_normalization(
+        self, mock_get_context, uppercase_dataset_context: DatasetContext
+    ) -> None:
+        """group_by and filter for same column normalize to same canonical 
name."""
+        mock_get_context.return_value = uppercase_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ds"),
+            y=[ColumnRef(name="distance", aggregate="AVG")],
+            kind="bar",
+            group_by=ColumnRef(name="Airline"),
+            filters=[FilterConfig(column="airline", op="=", value="AA")],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+        assert normalized.group_by is not None
+        assert normalized.filters is not None
+        assert normalized.group_by.name == normalized.filters[0].column == "AIRLINE"
diff --git a/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py b/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py
index af08834d57b..0a8771e48ba 100644
--- a/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py
+++ b/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py
@@ -35,6 +35,7 @@ from superset.mcp_service.chart.schemas import (
     TableChartConfig,
     XYChartConfig,
 )
+from superset.mcp_service.common.error_schemas import DatasetContext
 
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
@@ -706,3 +707,151 @@ class TestGenerateExploreLink:
             assert result.data["form_data"].get("x_axis") == "date"
             # Verify datasource field format: "{dataset_id}__table"
             assert result.data["form_data"].get("datasource") == "1__table"
+
+
+class TestGenerateExploreLinkColumnNormalization:
+    """Tests that generate_explore_link normalizes column names.
+
+    This verifies the fix where user-provided column names in wrong case
+    (e.g., 'order_date') are normalized to the canonical dataset name
+    (e.g., 'OrderDate') before being used in form_data.
+    """
+
+    @patch(
+        "superset.mcp_service.chart.validation.dataset_validator.DatasetValidator._get_dataset_context"
+    )
+    @patch("superset.daos.dataset.DatasetDAO.find_by_id")
+    @patch(
+        "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
+    )
+    @pytest.mark.asyncio
+    async def test_xy_chart_x_axis_normalized_in_form_data(
+        self,
+        mock_create_form_data,
+        mock_find_dataset,
+        mock_get_context,
+        mcp_server,
+    ):
+        """x-axis column name in wrong case is normalized in form_data."""
+        mock_create_form_data.return_value = "norm_test_key_1"
+        mock_find_dataset.return_value = _mock_dataset(id=18)
+        mock_get_context.return_value = DatasetContext(
+            id=18,
+            table_name="Vehicle Sales",
+            schema="public",
+            database_name="examples",
+            available_columns=[
+                {"name": "OrderDate", "type": "DATE", "is_temporal": True},
+                {"name": "Sales", "type": "FLOAT", "is_numeric": True},
+            ],
+            available_metrics=[],
+        )
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+        )
+        request = GenerateExploreLinkRequest(dataset_id="18", config=config)
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "generate_explore_link", {"request": request.model_dump()}
+            )
+
+            assert result.data["error"] is None
+            # x-axis should be normalized from 'orderdate' to 'OrderDate'
+            assert result.data["form_data"]["x_axis"] == "OrderDate"
+
+    @patch(
+        "superset.mcp_service.chart.validation.dataset_validator.DatasetValidator._get_dataset_context"
+    )
+    @patch("superset.daos.dataset.DatasetDAO.find_by_id")
+    @patch(
+        "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
+    )
+    @pytest.mark.asyncio
+    async def test_filter_column_normalized_in_form_data(
+        self,
+        mock_create_form_data,
+        mock_find_dataset,
+        mock_get_context,
+        mcp_server,
+    ):
+        """Filter column name in wrong case is normalized in adhoc_filters."""
+        mock_create_form_data.return_value = "norm_test_key_2"
+        mock_find_dataset.return_value = _mock_dataset(id=18)
+        mock_get_context.return_value = DatasetContext(
+            id=18,
+            table_name="Vehicle Sales",
+            schema="public",
+            database_name="examples",
+            available_columns=[
+                {"name": "OrderDate", "type": "DATE", "is_temporal": True},
+                {"name": "Sales", "type": "FLOAT", "is_numeric": True},
+            ],
+            available_metrics=[],
+        )
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+            filters=[
+                FilterConfig(column="orderdate", op=">", value="2023-01-01"),
+            ],
+        )
+        request = GenerateExploreLinkRequest(dataset_id="18", config=config)
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "generate_explore_link", {"request": request.model_dump()}
+            )
+
+            assert result.data["error"] is None
+            form_data = result.data["form_data"]
+            # x-axis normalized
+            assert form_data["x_axis"] == "OrderDate"
+            # filter subject normalized to match x-axis
+            adhoc_filters = form_data.get("adhoc_filters", [])
+            assert len(adhoc_filters) == 1
+            assert adhoc_filters[0]["subject"] == "OrderDate"
+
+    @patch(
+        "superset.mcp_service.chart.validation.dataset_validator.DatasetValidator._get_dataset_context"
+    )
+    @patch("superset.daos.dataset.DatasetDAO.find_by_id")
+    @patch(
+        "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
+    )
+    @pytest.mark.asyncio
+    async def test_normalization_fallback_when_dataset_not_found(
+        self,
+        mock_create_form_data,
+        mock_find_dataset,
+        mock_get_context,
+        mcp_server,
+    ):
+        """When dataset context is unavailable, original names pass through."""
+        mock_create_form_data.return_value = "norm_test_key_3"
+        mock_find_dataset.return_value = _mock_dataset(id=99)
+        mock_get_context.return_value = None
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+        )
+        request = GenerateExploreLinkRequest(dataset_id="99", config=config)
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "generate_explore_link", {"request": request.model_dump()}
+            )
+
+            assert result.data["error"] is None
+            # original names should pass through unchanged
+            assert result.data["form_data"]["x_axis"] == "orderdate"
