andygrove commented on code in PR #21508:
URL: https://github.com/apache/datafusion/pull/21508#discussion_r3068302707


##########
datafusion/spark/scripts/validate_slt.py:
##########
@@ -0,0 +1,1210 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Validate hardcoded expected values in .slt (sqllogictest) test files
+by running the same queries against PySpark and comparing results.
+
+Usage:
+    python validate_slt.py                          # Run all .slt files
+    python validate_slt.py --path math/abs.slt      # Single file
+    python validate_slt.py --path string/           # All files in subdirectory
+    python validate_slt.py --verbose                 # Show details
+    python validate_slt.py --show-skipped            # Show skipped queries
+"""
+
+import argparse
+import math
+import os
+import re
+import sys
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Optional
+
+# ---------------------------------------------------------------------------
+# Arrow type -> Spark type mapping
+# ---------------------------------------------------------------------------
# Arrow logical type name -> Spark SQL type name. Spark has no unsigned
# integers, so each UIntN is widened to the next signed type that holds it.
# NOTE(review): UInt64 maps to BIGINT, which cannot represent values above
# 2**63 - 1 — confirm no .slt input relies on such values.
ARROW_TO_SPARK_TYPE = dict(
    Int8="TINYINT",
    Int16="SMALLINT",
    Int32="INT",
    Int64="BIGINT",
    UInt8="SMALLINT",
    UInt16="INT",
    UInt32="BIGINT",
    UInt64="BIGINT",
    Float16="FLOAT",
    Float32="FLOAT",
    Float64="DOUBLE",
    Utf8="STRING",
    Boolean="BOOLEAN",
    Binary="BINARY",
    Date32="DATE",
    Date64="DATE",
)
+
+# DataFusion cast type -> Spark type mapping
# Type names that may appear in DataFusion CAST(... AS <type>) expressions,
# mapped to the Spark SQL type name to use instead.
DF_TO_SPARK_CAST_TYPE = dict(
    TINYINT="TINYINT",
    SMALLINT="SMALLINT",
    INT="INT",
    INTEGER="INT",
    BIGINT="BIGINT",
    FLOAT="FLOAT",
    REAL="FLOAT",
    DOUBLE="DOUBLE",
    STRING="STRING",
    VARCHAR="STRING",
    TEXT="STRING",
    BOOLEAN="BOOLEAN",
    BINARY="BINARY",
    DATE="DATE",
    TIMESTAMP="TIMESTAMP",
    # PostgreSQL-style aliases that show up in some .slt files.
    FLOAT8="DOUBLE",
    FLOAT4="FLOAT",
    INT8="BIGINT",
    INT4="INT",
    INT2="SMALLINT",
    BYTEA="BINARY",
)
+
+# Unsupported Arrow types for Spark
# Arrow type names with no Spark equivalent in ARROW_TO_SPARK_TYPE.
UNSUPPORTED_ARROW_TYPES = set("Utf8View LargeUtf8 LargeBinary BinaryView".split())
+
+# ---------------------------------------------------------------------------
+# SLT record types
+# ---------------------------------------------------------------------------
+
+
@dataclass
class QueryRecord:
    """Parsed ``query <TYPE_CODES> [rowsort]`` record from an .slt file."""

    # Column type codes taken from the header line.
    type_codes: str
    # SQL text between the header and the ``----`` separator.
    sql: str
    # Expected result lines exactly as written after ``----``.
    expected: list[str]
    # True when the header carried the ``rowsort`` modifier.
    rowsort: bool
    # 1-based line number of the header line.
    line_number: int
    # ANSI-mode flag in effect when the record was parsed.
    in_ansi_block: bool = field(default=False)
+
+
@dataclass
class ErrorRecord:
    """Parsed ``query error <pattern>`` or ``statement error <pattern>`` record."""

    # Text following ``error`` on the header line (may be empty).
    pattern: str
    # SQL text that follows the header.
    sql: str
    # 1-based line number of the header line.
    line_number: int
    # Which header introduced the record: "query" or "statement".
    kind: str = field(default="query")
    # ANSI-mode flag in effect when the record was parsed.
    in_ansi_block: bool = field(default=False)
+
+
@dataclass
class StatementRecord:
    """Parsed ``statement ok`` record (DDL or configuration SQL)."""

    # SQL text that follows the header.
    sql: str
    # 1-based line number of the header line.
    line_number: int
    # ANSI-mode flag in effect when the record was parsed.
    in_ansi_block: bool = field(default=False)
+
+
+# ---------------------------------------------------------------------------
+# 1. SLT Parser
+# ---------------------------------------------------------------------------
+
+
# Start of any sqllogictest record header; ends SQL and result blocks.
_RECORD_START_RE = re.compile(r"^(query|statement)\s")
# `query error` with an OPTIONAL pattern, so a bare `query error` header is
# not misread as `query <TYPE_CODES>` with type codes "error".
_QUERY_ERROR_RE = re.compile(r"^query\s+error(?:\s+(.*))?$")
_STATEMENT_ERROR_RE = re.compile(r"^statement\s+error\s*(.*)")
_STATEMENT_OK_RE = re.compile(r"^statement\s+ok\s*$")
_QUERY_RE = re.compile(r"^query\s+(\S+)(\s+rowsort)?\s*$")
# ANSI-mode toggles, recognised inside `statement ok` SQL.
_ANSI_ON_RE = re.compile(
    r"set\s+datafusion\.execution\.enable_ansi_mode\s*=\s*true", re.IGNORECASE
)
_ANSI_OFF_RE = re.compile(
    r"set\s+datafusion\.execution\.enable_ansi_mode\s*=\s*false", re.IGNORECASE
)
# Line separating a record's SQL from its expected output / error message.
_SEPARATOR = "----"


def _collect_sql(lines: list[str], i: int) -> tuple[str, int]:
    """Collect the SQL body of a statement/error record starting at ``lines[i]``.

    Stops before a blank line, a ``#`` comment, another record header, or a
    ``----`` separator. Returns the joined SQL and the index of the first
    line that was not consumed.
    """
    sql_lines = []
    while i < len(lines):
        raw = lines[i].rstrip("\n")
        stripped = raw.strip()
        if (
            not stripped
            or stripped.startswith("#")
            or raw == _SEPARATOR
            or _RECORD_START_RE.match(raw)
        ):
            break
        sql_lines.append(raw)
        i += 1
    return "\n".join(sql_lines), i


def _skip_error_message(lines: list[str], i: int) -> int:
    """Skip an optional ``----`` separator plus the message lines up to a blank line.

    Returns ``i`` unchanged when ``lines[i]`` is not a ``----`` separator.
    """
    if i < len(lines) and lines[i].rstrip("\n") == _SEPARATOR:
        i += 1
        while i < len(lines) and lines[i].strip():
            i += 1
    return i


def _collect_query_sql(lines: list[str], i: int) -> tuple[str, bool, int]:
    """Collect a query's SQL starting at ``lines[i]``.

    The SQL ends at the ``----`` separator (which is consumed) or at a blank
    line. A blank line ends every other record in this parser, so stopping
    there keeps a query without ``----`` from swallowing the records after it.
    Returns the joined SQL, whether the separator was found, and the next
    index to read.
    """
    sql_lines = []
    while i < len(lines):
        raw = lines[i].rstrip("\n")
        if raw == _SEPARATOR:
            return "\n".join(sql_lines), True, i + 1
        if not raw.strip():
            break
        sql_lines.append(raw)
        i += 1
    return "\n".join(sql_lines), False, i


def _collect_expected(lines: list[str], i: int) -> tuple[list[str], int]:
    """Collect expected result rows that follow ``----``, starting at ``lines[i]``.

    Rows end at an empty line (consumed), another record header, or a ``##``
    comment line (both left for the caller). A single leading ``#`` is NOT
    treated as a comment because result values can start with ``#``
    (e.g. ``soundex('#')`` -> ``#``).
    """
    expected = []
    while i < len(lines):
        row = lines[i].rstrip("\n")
        if row == "":
            return expected, i + 1
        if _RECORD_START_RE.match(row) or row.startswith("##"):
            break
        expected.append(row)
        i += 1
    return expected, i


def parse_slt(filepath: str) -> list:
    """Parse an .slt file into a list of records.

    Recognised records are ``query error``, ``statement error``,
    ``statement ok`` and ``query <TYPE_CODES> [rowsort]``. Any other line is
    skipped. A ``statement ok`` whose SQL sets
    ``datafusion.execution.enable_ansi_mode`` to true/false toggles the ANSI
    flag. The flag is stamped on that statement and on every later record.

    Args:
        filepath: Path to the .slt file; read as UTF-8.

    Returns:
        ``QueryRecord``, ``ErrorRecord`` and ``StatementRecord`` instances in
        file order.
    """
    # .slt files carry non-ASCII test data; don't depend on the locale encoding.
    with open(filepath, encoding="utf-8") as f:
        lines = f.readlines()

    records = []
    i = 0
    in_ansi_mode = False

    while i < len(lines):
        line = lines[i].rstrip("\n")
        stripped = line.strip()

        # Skip blank lines and comments between records.
        if not stripped or stripped.startswith("#"):
            i += 1
            continue

        line_num = i + 1

        # query error <pattern> / statement error <pattern>
        error_kind = None
        m = _QUERY_ERROR_RE.match(line)
        if m:
            error_kind = "query"
        else:
            m = _STATEMENT_ERROR_RE.match(line)
            if m:
                error_kind = "statement"
        if error_kind is not None:
            sql, i = _collect_sql(lines, i + 1)
            # Keep a `----` separator and multi-line expected error out of the SQL.
            i = _skip_error_message(lines, i)
            records.append(
                ErrorRecord(
                    pattern=(m.group(1) or "").strip(),
                    sql=sql,
                    line_number=line_num,
                    kind=error_kind,
                    in_ansi_block=in_ansi_mode,
                )
            )
            continue

        # statement ok
        if _STATEMENT_OK_RE.match(line):
            sql, i = _collect_sql(lines, i + 1)

            # Track ANSI mode from statements (applies to this record too).
            if _ANSI_ON_RE.search(sql):
                in_ansi_mode = True
            elif _ANSI_OFF_RE.search(sql):
                in_ansi_mode = False

            records.append(
                StatementRecord(
                    sql=sql, line_number=line_num, in_ansi_block=in_ansi_mode
                )
            )
            continue

        # query <TYPE_CODES> [rowsort]
        m = _QUERY_RE.match(line)
        if m:
            sql, has_results, i = _collect_query_sql(lines, i + 1)
            expected = []
            if has_results:
                expected, i = _collect_expected(lines, i)
            records.append(
                QueryRecord(
                    type_codes=m.group(1),
                    sql=sql,
                    expected=expected,
                    rowsort=m.group(2) is not None,
                    line_number=line_num,
                    in_ansi_block=in_ansi_mode,
                )
            )
            continue

        # Unknown line, skip
        i += 1

    return records
+
+
+# ---------------------------------------------------------------------------
+# 2. SQL Translator (DataFusion -> PySpark)

Review Comment:
   Probably makes sense. I have not used SQLGlot, but I can look into it for a
future PR.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to