Omega359 commented on code in PR #21707: URL: https://github.com/apache/datafusion/pull/21707#discussion_r3121236785
########## benchmarks/src/sql_benchmark.rs: ########## @@ -0,0 +1,3464 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, RecordBatch}; +use arrow::datatypes::*; +use arrow::error::ArrowError; +use arrow::util::display::{ArrayFormatter, FormatOptions}; +use datafusion::dataframe::DataFrameWriteOptions; +use datafusion::datasource::MemTable; +use datafusion::physical_plan::execute_stream; +use datafusion::prelude::{CsvReadOptions, DataFrame, SessionContext}; +use datafusion_common::config::CsvOptions; +use datafusion_common::{DataFusionError, Result, exec_datafusion_err}; +use futures::StreamExt; +use log::{debug, info, trace}; +use regex::Regex; +use std::collections::HashMap; +use std::fmt::Debug; +use std::fs::{self, File, OpenOptions}; +use std::io::{BufRead, BufReader}; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, LazyLock}; + +/// A collection of benchmark configurations and state used by the DataFusion +/// sql test harness. Each benchmark is defined by a file that can contain +/// directives such as `load`, `run`, `assert`, `result`, etc. 
The +/// `SqlBenchmark` struct holds the parsed data from that file and +/// the impl provides methods to run, assert, persist, verify and cleanup benchmark +/// results. +#[derive(Debug, Clone)] +pub struct SqlBenchmark { + /// Human‑readable name of the benchmark. + name: String, + /// Top‑level group name (derived from the file path or defined in a benchmark). + group: String, + /// Subgroup name, often a logical grouping. + subgroup: String, + /// Full path to the benchmark file. + benchmark_path: PathBuf, + /// Mapping of placeholder keys to concrete values (e.g. `"BENCHMARK_DIR"`). + replacement_mapping: HashMap<String, String>, + /// Expected string that must appear in the physical plan of the queries. + expect: Vec<String>, + /// All SQL queries grouped by directive (`load`, `run`, etc.). + queries: HashMap<QueryDirective, Vec<String>>, + /// Queries whose results are persisted to disk for later comparison. + result_queries: Vec<BenchmarkQuery>, + /// Queries whose results are asserted against an expected table. + assert_queries: Vec<BenchmarkQuery>, + /// Flag indicating whether the benchmark has been fully loaded + is_loaded: bool, + /// Stores the last run results if needed so they can be compared or persisted. 
+ last_results: Option<Vec<RecordBatch>>, + /// echo statements + echo: Vec<String>, +} + +impl SqlBenchmark { + pub async fn new( + ctx: &SessionContext, + full_path: impl AsRef<Path>, + benchmark_directory: impl AsRef<Path>, + ) -> Result<Self> { + let full_path = full_path.as_ref(); + let benchmark_directory = benchmark_directory.as_ref(); + let group_name = parse_group_from_path(full_path, benchmark_directory); + let mut bm = Self { + name: String::new(), + group: group_name, + subgroup: String::new(), + benchmark_path: full_path.to_path_buf(), + replacement_mapping: HashMap::new(), + expect: vec![], + queries: HashMap::new(), + result_queries: vec![], + assert_queries: vec![], + is_loaded: false, + last_results: None, + echo: vec![], + }; + bm.replacement_mapping.insert( + "BENCHMARK_DIR".to_string(), + benchmark_directory.to_string_lossy().into_owned(), + ); + + let path = bm.benchmark_path.clone(); + bm.process_file(ctx, &path).await?; + + Ok(bm) + } + + /// Initializes the benchmark by executing `load` and `init` queries. + /// + /// Registers any required tables or sets up state in the provided + /// `SessionContext` before running queries. This method is idempotent: + /// calling it multiple times on the same instance returns + /// immediately after the first successful initialization. + /// + /// # Errors + /// Returns an error if any `load` or `init` query fails, or if the + /// benchmark file does not contain a `run` query. 
+ pub async fn initialize(&mut self, ctx: &SessionContext) -> Result<()> { + if self.is_loaded { + return Ok(()); + } + + let path = self.benchmark_path.to_string_lossy().into_owned(); + + // validate there was a run query + if !self.queries.contains_key(&QueryDirective::Run) { + return Err(exec_datafusion_err!( + "Invalid benchmark file: no \"run\" query specified: {path}" + )); + } + + // display any echo's + self.echo.iter().for_each(|txt| println!("{txt}")); + + let load_queries = self.queries.get(&QueryDirective::Load); + + if let Some(queries) = load_queries { + for query in queries { + debug!("Executing load query {query}"); + ctx.sql(query).await?.collect().await?; + } + } + + let init_queries = self.queries.get(&QueryDirective::Init); + + if let Some(queries) = init_queries { + for query in queries { + debug!("Executing init query {query}"); + ctx.sql(query).await?.collect().await?; + } + } + + self.is_loaded = true; + + Ok(()) + } + + /// Executes the `assert` queries and compares actual results against + /// expected values. + /// + /// Each `assert` query must be followed by a result table (separated by + /// `----`) in the benchmark file. The assertion passes only if the + /// returned record batches exactly match the expected rows. + /// + /// # Errors + /// Returns an error if any `assert` query fails, or if the actual and + /// expected results differ in row count or cell values. + pub async fn assert(&mut self, ctx: &SessionContext) -> Result<()> { + info!("Running assertions..."); + + for assert_query in &self.assert_queries { + let query = &assert_query.query; + + info!("Executing assert query {query}"); + + let result = ctx.sql(query).await?.collect().await?; + let formatted_actual_results = format_record_batches(&result)?; + + Self::compare_results( + assert_query, + &formatted_actual_results, + &assert_query.expected_result, + )?; + } + + Ok(()) + } + + /// Executes the `run` queries, optionally saving results for later + /// verification. 
If there are multiple queries only the results for + /// the last query are saved. + /// + /// When `save_results` is `true`, it collects `SELECT`/`WITH` query + /// results and stores them in `last_results`. + /// + /// When `save_results` is `false`, it streams results and counts rows + /// without buffering them. + /// + /// If an 'expect' string is defined this method also validates that + /// the physical plan contains that string. + /// + /// # Errors + /// Returns an error if a `run` query fails or if expected plan strings + /// are not found. + pub async fn run(&mut self, ctx: &SessionContext, save_results: bool) -> Result<()> { + let run_queries = self + .queries + .get(&QueryDirective::Run) + .ok_or_else(|| exec_datafusion_err!("Run query should be loaded by now"))?; + + let mut result_count = 0; + + let result: Vec<RecordBatch> = { + let mut local_result = vec![]; + + for query in run_queries { + match save_results { + true => { + debug!( + "Running query (saving results) {}-{}: {query}", + self.group, self.subgroup + ); + + let df = ctx.sql(query).await?; + if !self.expect.is_empty() { + let physical_plan = df.create_physical_plan().await?; + self.validate_expected_plan(&physical_plan)?; + } + + let result_schema = Arc::new(df.schema().as_arrow().clone()); + let mut batches = df.collect().await?; + let trimmed = query.trim_start(); + + // save the output for select/with queries + if starts_with_ignore_ascii_case(trimmed, "select") + || starts_with_ignore_ascii_case(trimmed, "with") + { + if batches.is_empty() { + batches.push(RecordBatch::new_empty(result_schema)); + } + let row_count_for_query = + batches.iter().map(RecordBatch::num_rows).sum::<usize>(); + debug!( + "Persisting {} batches ({} rows)...", + batches.len(), + row_count_for_query + ); + + result_count = row_count_for_query; + local_result = batches; + } + } + false => { + debug!( + "Running query (ignoring results) {}-{}: {query}", + self.group, self.subgroup + ); + + result_count = self + 
.execute_sql_without_result_buffering(query, ctx) + .await?; + } + } + } + + Ok::<Vec<RecordBatch>, DataFusionError>(local_result) + }?; + + debug!("Results have {result_count} rows"); + + // Store results for verification + self.last_results = Some(result); Review Comment: Technically, yes, though that path would never occur, since `persist` is only called when `save_results` is true. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
