cetingokhan commented on code in PR #62963: URL: https://github.com/apache/airflow/pull/62963#discussion_r3077684100
########## providers/common/ai/src/airflow/providers/common/ai/utils/dq_planner.py: ########## @@ -0,0 +1,908 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +SQL-based data-quality plan generation and execution. + +:class:`SQLDQPlanner` is the single entry-point for all SQL DQ logic. +It is deliberately kept separate from the operator so it can be unit-tested +without an Airflow context and later swapped for GEX/SODA planners without +touching the operator. 
+""" + +from __future__ import annotations + +import logging +from collections.abc import Iterator, Sequence +from contextlib import closing +from typing import TYPE_CHECKING, Any + +try: + from airflow.providers.common.ai.utils.sql_validation import ( + DEFAULT_ALLOWED_TYPES, + SQLSafetyError, + validate_sql as _validate_sql, + ) +except ImportError as e: + from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException + + raise AirflowOptionalProviderFeatureException(e) + +from airflow.providers.common.ai.utils.db_schema import build_schema_context, resolve_dialect +from airflow.providers.common.ai.utils.dq_models import DQCheckGroup, DQPlan, RowLevelResult, UnexpectedResult +from airflow.providers.common.ai.utils.logging import log_run_summary + +if TYPE_CHECKING: + from pydantic_ai import Agent + from pydantic_ai.messages import ModelMessage + + from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook + from airflow.providers.common.sql.config import DataSourceConfig + from airflow.providers.common.sql.datafusion.engine import DataFusionEngine + from airflow.providers.common.sql.hooks.sql import DbApiHook + +log = logging.getLogger(__name__) + +_MAX_CHECKS_PER_GROUP = 5 +# Maximum rows fetched from DB per chunk during row-level processing — avoids loading the +# entire result set into memory at once. +_ROW_LEVEL_CHUNK_SIZE = 10_000 +# Hard cap on violation samples stored per check — independent of SQL LIMIT and chunk size. +_MAX_VIOLATION_SAMPLES = 100 + +_PLANNING_SYSTEM_PROMPT = """\ +You are a data-quality SQL expert. + +Given a set of named data-quality checks and a database schema, produce a \ +DQPlan that minimises the number of SQL queries while keeping each group \ +focused and manageable. + +GROUPING STRATEGY (multi-dimensional): + Group checks by **(target_table, check_category)**. Checks on the same table + that belong to different categories MUST be in separate groups. 
+ + Allowed check_category values (assign one per check based on its description): + - null_check — null / missing value counts or percentages + - uniqueness — duplicate detection, cardinality checks + - validity — regex / format / pattern matching on string columns + - numeric_range — range, bounds, or statistical checks on numeric columns + - row_count — total row counts or existence checks + - string_format — length, encoding, whitespace, or character-set checks + - row_level — per-row or anomaly checks that evaluate individual records + + Row-level checks still follow the same grouping rule: group by (target_table, check_category="row_level"). + MAX {max_checks_per_group} CHECKS PER GROUP: + If a (table, category) pair has more than {max_checks_per_group} checks, + split them into sub-groups of at most {max_checks_per_group}. + + GROUP-ID NAMING: + Use the pattern "{{table}}_{{category}}_{{part}}". + Examples: customers_null_check_1, orders_validity_1, orders_validity_2 + + RATIONALE: + Keeping string-column checks (validity, string_format) apart from + numeric-column checks (numeric_range, null_check on numbers) produces + simpler SQL and makes failures easier to diagnose. + + CORRECT (two groups for same table, different categories): + Group customers_null_check_1: + SELECT + (COUNT(CASE WHEN email IS NULL THEN 1 END) * 100.0 / COUNT(*)) AS null_email_pct, + (COUNT(CASE WHEN name IS NULL THEN 1 END) * 100.0 / COUNT(*)) AS null_name_pct + FROM customers + + Group customers_validity_1: + SELECT + COUNT(CASE WHEN phone NOT LIKE '+___-___-____' THEN 1 END) AS invalid_phone_fmt + FROM customers + + WRONG (mixing null-check and regex-validity in one group): + SELECT + (COUNT(CASE WHEN email IS NULL THEN 1 END) * 100.0 / COUNT(*)) AS null_email_pct, + COUNT(CASE WHEN phone NOT LIKE '+___-___-____' THEN 1 END) AS invalid_phone_fmt + FROM customers + +OUTPUT RULES: + 1. Each output column must be aliased to exactly the metric_key of its check. + Example: ... 
AS null_email_pct + 2. Each check_name must exactly match the key in the prompts dict. + 3. metric_key values must be valid SQL column aliases (snake_case, no spaces). + 4. Generates only SELECT queries — no INSERT, UPDATE, DELETE, DROP, or DDL. + 5. Use {dialect} syntax. + 6. Each check must appear in exactly ONE group. + 7. Each check must have a check_category from the allowed list above. + 8. Return a valid DQPlan object. No extra commentary. +""" + +_DATAFUSION_SYNTAX_SECTION = """\ + +DATAFUSION SQL SYNTAX RULES: + The target engine is Apache DataFusion. Observe these syntax differences + from standard PostgreSQL / ANSI SQL: + + 1. NO "FILTER (WHERE ...)" clause. Use CASE expressions instead: + WRONG: COUNT(*) FILTER (WHERE email IS NULL) + RIGHT: COUNT(CASE WHEN email IS NULL THEN 1 END) + + 2. Regex matching uses the tilde operator: + column ~ 'pattern' (match) + column !~ 'pattern' (no match) + Do NOT use SIMILAR TO or POSIX-style ~* (case-insensitive). + + 3. CAST syntax — prefer CAST(expr AS type) over :: shorthand. + + 4. String functions: Use CHAR_LENGTH (not LEN), SUBSTR (not SUBSTRING with FROM/FOR). + + 5. Integer division: DataFusion performs integer division for INT/INT. + Use CAST(expr AS DOUBLE) to force floating-point division. + + 6. Boolean literals: Use TRUE / FALSE (not 1 / 0). + + 7. LIMIT is supported. OFFSET is supported. FETCH FIRST is NOT supported. + + 8. NULL handling: COALESCE, NULLIF, IFNULL are all supported. + NVL and ISNULL are NOT supported. +""" + +_UNEXPECTED_QUERY_PROMPT_SECTION = """\ + +UNEXPECTED VALUE COLLECTION: + For checks whose check_category is "validity" or "string_format", also + generate an unexpected_query field on the DQCheck. 
This query must: + - SELECT the primary key column(s) and the column(s) being validated + - WHERE the row violates the check condition (the negation of the check) + - LIMIT {sample_size} + - Use {dialect} syntax + - Be a standalone SELECT (not a subquery of the group query) + + For all other categories (null_check, uniqueness, numeric_range, row_count), + set unexpected_query to null — these are aggregate checks where individual + violating rows are not meaningful. + + Example for a phone-format validity check: + unexpected_query: "SELECT id, phone FROM customers WHERE phone !~ '^\\d{{4}}-\\d{{4}}-\\d{{4}}$' LIMIT 100" +""" + +_ROW_LEVEL_PROMPT_SECTION = """ + +ROW-LEVEL CHECKS: + Some checks are marked as row_level. For these: + - Generate a SELECT that returns the primary key column(s) and the column + being validated. Do NOT aggregate. + - Set row_level = true on the DQCheck entry. + - metric_key must be the name of the column containing the value to validate + (the Python validator will read row[metric_key] for each row). + - {row_level_limit_clause} + - Place ALL row-level checks for the same table in a single group. + + Row-level check names that require this treatment: {row_level_check_names} +""" + + +class SQLDQPlanner: + """ + Generates and executes a SQL-based :class:`~airflow.providers.common.ai.utils.dq_models.DQPlan`. + + :param llm_hook: Hook used to call the LLM for plan generation. + :param db_hook: Hook used to execute generated SQL against the database. + :param dialect: SQL dialect forwarded to the LLM prompt and ``validate_sql``. + Auto-detected from *db_hook* when ``None``. + :param max_sql_retries: Maximum number of times a failing SQL group query is sent + back to the LLM for correction before the error is re-raised. Default ``2``. + :param validator_contexts: Pre-built LLM context string from + :meth:`~airflow.providers.common.ai.utils.dq_validation.ValidatorRegistry.build_llm_context`. 
+ Appended to the system prompt so the LLM knows what metric format each + custom validator expects. + :param row_validators: Mapping of ``{check_name: row_level_callable}`` for + checks that require row-by-row Python validation. When a check's name + appears here, ``execute_plan`` fetches all (or sampled) rows and applies + the callable to each value instead of reading a single aggregate scalar. + :param row_level_sample_size: Maximum number of rows to fetch for row-level + checks. ``None`` (default) performs a full scan. A positive integer + instructs the LLM to add ``LIMIT N`` to the generated SELECT. + """ + + def __init__( + self, + *, + llm_hook: PydanticAIHook, + db_hook: DbApiHook | None, + dialect: str | None = None, + max_sql_retries: int = 2, + datasource_config: DataSourceConfig | None = None, + system_prompt: str = "", + agent_params: dict[str, Any] | None = None, + collect_unexpected: bool = False, + unexpected_sample_size: int = 100, + validator_contexts: str = "", + row_validators: dict[str, Any] | None = None, + row_level_sample_size: int | None = None, + ) -> None: + self._llm_hook = llm_hook + self._db_hook = db_hook + self._datasource_config = datasource_config + self._dialect = resolve_dialect(db_hook, dialect) + # Track whether the execution target is DataFusion so the prompt can + # include DataFusion-specific syntax rules. The dialect stays None + # (generic SQL) for sqlglot validation — sqlglot has no DataFusion dialect. + self._is_datafusion = db_hook is None and datasource_config is not None + # When targeting DataFusion, use PostgreSQL dialect for sqlglot validation + # because DataFusion shares regex operators (~, !~) that the generic SQL + # parser does not recognise. 
+ self._validation_dialect: str | None = "postgres" if self._is_datafusion else self._dialect + self._max_sql_retries = max_sql_retries + self._extra_system_prompt = system_prompt + self._agent_params: dict[str, Any] = agent_params or {} + self._collect_unexpected = collect_unexpected + self._unexpected_sample_size = unexpected_sample_size + self._validator_contexts = validator_contexts + self._row_validators: dict[str, Any] = row_validators or {} + self._row_level_sample_size = row_level_sample_size + self._cached_datafusion_engine: DataFusionEngine | None = None + self._plan_agent: Agent[None, DQPlan] | None = None + self._plan_all_messages: list[ModelMessage] | None = None + + def build_schema_context( + self, + table_names: list[str] | None, + schema_context: str | None, + ) -> str: + """ + Return a schema description string for inclusion in the LLM prompt. + + Delegates to :func:`~airflow.providers.common.ai.utils.db_schema.build_schema_context`. + """ + return build_schema_context( + db_hook=self._db_hook, + table_names=table_names, + schema_context=schema_context, + datasource_config=self._datasource_config, + ) + + def generate_plan(self, prompts: dict[str, str], schema_context: str) -> DQPlan: + """ + Ask the LLM to produce a :class:`~airflow.providers.common.ai.utils.dq_models.DQPlan`. + + The LLM receives the user prompts, schema context, and planning instructions + as a structured-output call (``output_type=DQPlan``). After generation the + method verifies that the returned ``check_names`` exactly match + ``prompts.keys()``. + + :param prompts: ``{check_name: natural_language_description}`` dict. + :param schema_context: Schema description previously built via + :meth:`build_schema_context`. + :raises ValueError: If the LLM's plan does not cover every prompt key + exactly once. 
+ """ + dialect_label = self._dialect or ("DataFusion-compatible SQL" if self._is_datafusion else "SQL") + system_prompt = _PLANNING_SYSTEM_PROMPT.format( + dialect=dialect_label, max_checks_per_group=_MAX_CHECKS_PER_GROUP + ) + + if self._is_datafusion: + system_prompt += _DATAFUSION_SYNTAX_SECTION + + if self._collect_unexpected: + system_prompt += _UNEXPECTED_QUERY_PROMPT_SECTION.format( + dialect=dialect_label, sample_size=self._unexpected_sample_size + ) + + if schema_context: + system_prompt += f"\nAvailable schema:\n{schema_context}\n" + + if self._validator_contexts: + system_prompt += self._validator_contexts + + if self._row_validators: + row_level_check_names = ", ".join(sorted(self._row_validators)) + if self._row_level_sample_size is not None: + limit_clause = f"Add LIMIT {self._row_level_sample_size} to the query." + else: + limit_clause = "Do NOT add a LIMIT — return all rows." + system_prompt += _ROW_LEVEL_PROMPT_SECTION.format( + row_level_check_names=row_level_check_names, + row_level_limit_clause=limit_clause, + ) + + if self._extra_system_prompt: + system_prompt += f"\nAdditional instructions:\n{self._extra_system_prompt}\n" + + user_message = self._build_user_message(prompts) + + log.info("Using system prompt:\n%s", system_prompt) + log.info("Using user message:\n%s", user_message) + Review Comment: I've changed these two log calls to debug level. That said, I don't think logging them at info level would be a problem: they are only printed on the first run, because once the cache is loaded the process continues without generating a context again. I've left them at debug for now, but we can switch them back to info depending on the situation. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
