From a36066d24611528cb319f9d929d4278ec37c7d21 Mon Sep 17 00:00:00 2001
From: Anthonin Bonnefoy <anthonin.bonnefoy@datadoghq.com>
Date: Thu, 3 Oct 2024 08:52:02 +0200
Subject: Track location to extract relevant part in nested statement

Previously, Query nodes generated through transform had stmt_location
unset. Extensions relying on the statement location to extract the
relevant part of the statement would fall back to using the whole
statement instead, thus showing the same string at both the top and
nested levels, which was a source of confusion.

This patch fixes the issue by keeping track of the statement locations
and propagating them to Query during transform, allowing pgss to show
only the relevant part of the query for nested queries.
---
 .../expected/level_tracking.out               | 159 ++++++++++++++++++
 .../pg_stat_statements/sql/level_tracking.sql |  26 +++
 src/backend/parser/analyze.c                  |  10 ++
 src/backend/parser/gram.y                     |  17 +-
 src/backend/parser/parse_merge.c              |   2 +
 src/include/nodes/parsenodes.h                |   5 +
 6 files changed, 214 insertions(+), 5 deletions(-)

diff --git a/contrib/pg_stat_statements/expected/level_tracking.out b/contrib/pg_stat_statements/expected/level_tracking.out
index bb65e98ce09..5649c0b7c10 100644
--- a/contrib/pg_stat_statements/expected/level_tracking.out
+++ b/contrib/pg_stat_statements/expected/level_tracking.out
@@ -67,6 +67,165 @@ SELECT toplevel, calls, query FROM pg_stat_statements
  t        |     1 | SET pg_stat_statements.track = $1
 (7 rows)
 
+SELECT pg_stat_statements_reset() IS NOT NULL AS t;
+ t 
+---
+ t
+(1 row)
+
+-- Explain - all-level tracking.
+SET pg_stat_statements.track = 'all';
+explain (costs off) SELECT 1;
+ QUERY PLAN 
+------------
+ Result
+(1 row)
+
+explain (costs off) UPDATE stats_track_tab SET x=1 WHERE x=1;
+            QUERY PLAN             
+-----------------------------------
+ Update on stats_track_tab
+   ->  Seq Scan on stats_track_tab
+         Filter: (x = 1)
+(3 rows)
+
+explain (costs off) DELETE FROM stats_track_tab;
+            QUERY PLAN             
+-----------------------------------
+ Delete on stats_track_tab
+   ->  Seq Scan on stats_track_tab
+(2 rows)
+
+explain (costs off) INSERT INTO stats_track_tab VALUES ((1));
+        QUERY PLAN         
+---------------------------
+ Insert on stats_track_tab
+   ->  Result
+(2 rows)
+
+explain (costs off) MERGE INTO stats_track_tab USING (SELECT id FROM generate_series(1, 10) id) ON x = id
+    WHEN MATCHED THEN UPDATE SET x = id
+    WHEN NOT MATCHED THEN INSERT (x) VALUES (id);
+                      QUERY PLAN                       
+-------------------------------------------------------
+ Merge on stats_track_tab
+   ->  Hash Right Join
+         Hash Cond: (stats_track_tab.x = id.id)
+         ->  Seq Scan on stats_track_tab
+         ->  Hash
+               ->  Function Scan on generate_series id
+(6 rows)
+
+explain (costs off) SELECT 1 UNION SELECT 2;
+        QUERY PLAN        
+--------------------------
+ Unique
+   ->  Sort
+         Sort Key: (1)
+         ->  Append
+               ->  Result
+               ->  Result
+(6 rows)
+
+-- Check we correctly capture substring with CTE
+explain (costs off) WITH a AS (select 4) SELECT 1;
+ QUERY PLAN 
+------------
+ Result
+(1 row)
+
+explain (costs off) WITH a AS (select 4) UPDATE stats_track_tab SET x=1 WHERE x=1;
+            QUERY PLAN             
+-----------------------------------
+ Update on stats_track_tab
+   ->  Seq Scan on stats_track_tab
+         Filter: (x = 1)
+(3 rows)
+
+explain (costs off) WITH a AS (select 4) DELETE FROM stats_track_tab;
+            QUERY PLAN             
+-----------------------------------
+ Delete on stats_track_tab
+   ->  Seq Scan on stats_track_tab
+(2 rows)
+
+explain (costs off) WITH a AS (select 4) INSERT INTO stats_track_tab VALUES ((1));
+        QUERY PLAN         
+---------------------------
+ Insert on stats_track_tab
+   ->  Result
+(2 rows)
+
+explain (costs off) WITH a AS (select 4) MERGE INTO stats_track_tab USING (SELECT id FROM generate_series(1, 10) id) ON x = id
+    WHEN MATCHED THEN UPDATE SET x = id
+    WHEN NOT MATCHED THEN INSERT (x) VALUES (id);
+                      QUERY PLAN                       
+-------------------------------------------------------
+ Merge on stats_track_tab
+   ->  Hash Right Join
+         Hash Cond: (stats_track_tab.x = id.id)
+         ->  Seq Scan on stats_track_tab
+         ->  Hash
+               ->  Function Scan on generate_series id
+(6 rows)
+
+explain (costs off) WITH a AS (select 4) SELECT 1 UNION SELECT 2;
+        QUERY PLAN        
+--------------------------
+ Unique
+   ->  Sort
+         Sort Key: (1)
+         ->  Append
+               ->  Result
+               ->  Result
+(6 rows)
+
+SELECT toplevel, calls, query FROM pg_stat_statements
+  ORDER BY query COLLATE "C", toplevel;
+ toplevel | calls |                                                              query                                                               
+----------+-------+----------------------------------------------------------------------------------------------------------------------------------
+ f        |     1 | DELETE FROM stats_track_tab;
+ f        |     1 | INSERT INTO stats_track_tab VALUES (($1));
+ f        |     1 | MERGE INTO stats_track_tab USING (SELECT id FROM generate_series($1, $2) id) ON x = id                                          +
+          |       |     WHEN MATCHED THEN UPDATE SET x = id                                                                                         +
+          |       |     WHEN NOT MATCHED THEN INSERT (x) VALUES (id);
+ f        |     1 | SELECT $1 UNION SELECT $2;
+ f        |     1 | SELECT $1;
+ t        |     1 | SELECT pg_stat_statements_reset() IS NOT NULL AS t
+ t        |     1 | SET pg_stat_statements.track = $1
+ f        |     1 | UPDATE stats_track_tab SET x=$1 WHERE x=$2;
+ f        |     1 | WITH a AS (select $1) DELETE FROM stats_track_tab;
+ f        |     1 | WITH a AS (select $1) INSERT INTO stats_track_tab VALUES (($2));
+ f        |     1 | WITH a AS (select $1) MERGE INTO stats_track_tab USING (SELECT id FROM generate_series($2, $3) id) ON x = id                    +
+          |       |     WHEN MATCHED THEN UPDATE SET x = id                                                                                         +
+          |       |     WHEN NOT MATCHED THEN INSERT (x) VALUES (id);
+ f        |     1 | WITH a AS (select $1) SELECT $2 UNION SELECT $3;
+ f        |     1 | WITH a AS (select $1) SELECT $2;
+ f        |     1 | WITH a AS (select $1) UPDATE stats_track_tab SET x=$2 WHERE x=$3;
+ t        |     1 | explain (costs off) DELETE FROM stats_track_tab
+ t        |     1 | explain (costs off) INSERT INTO stats_track_tab VALUES (($1))
+ t        |     1 | explain (costs off) MERGE INTO stats_track_tab USING (SELECT id FROM generate_series($1, $2) id) ON x = id                      +
+          |       |     WHEN MATCHED THEN UPDATE SET x = id                                                                                         +
+          |       |     WHEN NOT MATCHED THEN INSERT (x) VALUES (id)
+ t        |     1 | explain (costs off) SELECT $1
+ t        |     1 | explain (costs off) SELECT $1 UNION SELECT $2
+ t        |     1 | explain (costs off) UPDATE stats_track_tab SET x=$1 WHERE x=$2
+ t        |     1 | explain (costs off) WITH a AS (select $1) DELETE FROM stats_track_tab
+ t        |     1 | explain (costs off) WITH a AS (select $1) INSERT INTO stats_track_tab VALUES (($2))
+ t        |     1 | explain (costs off) WITH a AS (select $1) MERGE INTO stats_track_tab USING (SELECT id FROM generate_series($2, $3) id) ON x = id+
+          |       |     WHEN MATCHED THEN UPDATE SET x = id                                                                                         +
+          |       |     WHEN NOT MATCHED THEN INSERT (x) VALUES (id)
+ t        |     1 | explain (costs off) WITH a AS (select $1) SELECT $2
+ t        |     1 | explain (costs off) WITH a AS (select $1) SELECT $2 UNION SELECT $3
+ t        |     1 | explain (costs off) WITH a AS (select $1) UPDATE stats_track_tab SET x=$2 WHERE x=$3
+(26 rows)
+
+SELECT pg_stat_statements_reset() IS NOT NULL AS t;
+ t 
+---
+ t
+(1 row)
+
 -- Procedure with multiple utility statements.
 CREATE OR REPLACE PROCEDURE proc_with_utility_stmt()
 LANGUAGE SQL
diff --git a/contrib/pg_stat_statements/sql/level_tracking.sql b/contrib/pg_stat_statements/sql/level_tracking.sql
index 65a17147a5a..1b431c9928c 100644
--- a/contrib/pg_stat_statements/sql/level_tracking.sql
+++ b/contrib/pg_stat_statements/sql/level_tracking.sql
@@ -32,6 +32,32 @@ BEGIN
 END; $$;
 SELECT toplevel, calls, query FROM pg_stat_statements
   ORDER BY query COLLATE "C", toplevel;
+SELECT pg_stat_statements_reset() IS NOT NULL AS t;
+
+-- Explain - all-level tracking.
+SET pg_stat_statements.track = 'all';
+explain (costs off) SELECT 1;
+explain (costs off) UPDATE stats_track_tab SET x=1 WHERE x=1;
+explain (costs off) DELETE FROM stats_track_tab;
+explain (costs off) INSERT INTO stats_track_tab VALUES ((1));
+explain (costs off) MERGE INTO stats_track_tab USING (SELECT id FROM generate_series(1, 10) id) ON x = id
+    WHEN MATCHED THEN UPDATE SET x = id
+    WHEN NOT MATCHED THEN INSERT (x) VALUES (id);
+explain (costs off) SELECT 1 UNION SELECT 2;
+
+-- Check we correctly capture substring with CTE
+explain (costs off) WITH a AS (select 4) SELECT 1;
+explain (costs off) WITH a AS (select 4) UPDATE stats_track_tab SET x=1 WHERE x=1;
+explain (costs off) WITH a AS (select 4) DELETE FROM stats_track_tab;
+explain (costs off) WITH a AS (select 4) INSERT INTO stats_track_tab VALUES ((1));
+explain (costs off) WITH a AS (select 4) MERGE INTO stats_track_tab USING (SELECT id FROM generate_series(1, 10) id) ON x = id
+    WHEN MATCHED THEN UPDATE SET x = id
+    WHEN NOT MATCHED THEN INSERT (x) VALUES (id);
+explain (costs off) WITH a AS (select 4) SELECT 1 UNION SELECT 2;
+
+SELECT toplevel, calls, query FROM pg_stat_statements
+  ORDER BY query COLLATE "C", toplevel;
+SELECT pg_stat_statements_reset() IS NOT NULL AS t;
 
 -- Procedure with multiple utility statements.
 CREATE OR REPLACE PROCEDURE proc_with_utility_stmt()
diff --git a/src/backend/parser/analyze.c b/src/backend/parser/analyze.c
index e901203424d..1cf489eb3a0 100644
--- a/src/backend/parser/analyze.c
+++ b/src/backend/parser/analyze.c
@@ -518,6 +518,7 @@ transformDeleteStmt(ParseState *pstate, DeleteStmt *stmt)
 	Node	   *qual;
 
 	qry->commandType = CMD_DELETE;
+	qry->stmt_location = stmt->location;
 
 	/* process the WITH clause independently of all else */
 	if (stmt->withClause)
@@ -525,6 +526,7 @@ transformDeleteStmt(ParseState *pstate, DeleteStmt *stmt)
 		qry->hasRecursive = stmt->withClause->recursive;
 		qry->cteList = transformWithClause(pstate, stmt->withClause);
 		qry->hasModifyingCTE = pstate->p_hasModifyingCTE;
+		qry->stmt_location = stmt->withClause->location;
 	}
 
 	/* set up range table with just the result rel */
@@ -606,6 +608,7 @@ transformInsertStmt(ParseState *pstate, InsertStmt *stmt)
 	Assert(pstate->p_ctenamespace == NIL);
 
 	qry->commandType = CMD_INSERT;
+	qry->stmt_location = stmt->location;
 	pstate->p_is_insert = true;
 
 	/* process the WITH clause independently of all else */
@@ -614,6 +617,7 @@ transformInsertStmt(ParseState *pstate, InsertStmt *stmt)
 		qry->hasRecursive = stmt->withClause->recursive;
 		qry->cteList = transformWithClause(pstate, stmt->withClause);
 		qry->hasModifyingCTE = pstate->p_hasModifyingCTE;
+		qry->stmt_location = stmt->withClause->location;
 	}
 
 	qry->override = stmt->override;
@@ -1347,6 +1351,7 @@ transformSelectStmt(ParseState *pstate, SelectStmt *stmt)
 	ListCell   *l;
 
 	qry->commandType = CMD_SELECT;
+	qry->stmt_location = stmt->location;
 
 	/* process the WITH clause independently of all else */
 	if (stmt->withClause)
@@ -1354,6 +1359,7 @@ transformSelectStmt(ParseState *pstate, SelectStmt *stmt)
 		qry->hasRecursive = stmt->withClause->recursive;
 		qry->cteList = transformWithClause(pstate, stmt->withClause);
 		qry->hasModifyingCTE = pstate->p_hasModifyingCTE;
+		qry->stmt_location = stmt->withClause->location;
 	}
 
 	/* Complain if we get called from someplace where INTO is not allowed */
@@ -1730,6 +1736,7 @@ transformSetOperationStmt(ParseState *pstate, SelectStmt *stmt)
 	int			tllen;
 
 	qry->commandType = CMD_SELECT;
+	qry->stmt_location = stmt->location;
 
 	/*
 	 * Find leftmost leaf SelectStmt.  We currently only need to do this in
@@ -1784,6 +1791,7 @@ transformSetOperationStmt(ParseState *pstate, SelectStmt *stmt)
 		qry->hasRecursive = withClause->recursive;
 		qry->cteList = transformWithClause(pstate, withClause);
 		qry->hasModifyingCTE = pstate->p_hasModifyingCTE;
+		qry->stmt_location = withClause->location;
 	}
 
 	/*
@@ -2429,6 +2437,7 @@ transformUpdateStmt(ParseState *pstate, UpdateStmt *stmt)
 	Node	   *qual;
 
 	qry->commandType = CMD_UPDATE;
+	qry->stmt_location = stmt->location;
 	pstate->p_is_insert = false;
 
 	/* process the WITH clause independently of all else */
@@ -2437,6 +2446,7 @@ transformUpdateStmt(ParseState *pstate, UpdateStmt *stmt)
 		qry->hasRecursive = stmt->withClause->recursive;
 		qry->cteList = transformWithClause(pstate, stmt->withClause);
 		qry->hasModifyingCTE = pstate->p_hasModifyingCTE;
+		qry->stmt_location = stmt->withClause->location;
 	}
 
 	qry->resultRelation = setTargetTable(pstate, stmt->relation,
diff --git a/src/backend/parser/gram.y b/src/backend/parser/gram.y
index 4aa8646af7b..dcb2588b396 100644
--- a/src/backend/parser/gram.y
+++ b/src/backend/parser/gram.y
@@ -190,7 +190,7 @@ static void insertSelectOptions(SelectStmt *stmt,
 								SelectLimit *limitClause,
 								WithClause *withClause,
 								core_yyscan_t yyscanner);
-static Node *makeSetOp(SetOperation op, bool all, Node *larg, Node *rarg);
+static Node *makeSetOp(SetOperation op, bool all, Node *larg, Node *rarg, int location);
 static Node *doNegate(Node *n, int location);
 static void doNegateFloat(Float *v);
 static Node *makeAndExpr(Node *lexpr, Node *rexpr, int location);
@@ -12170,6 +12170,7 @@ InsertStmt:
 					$5->onConflictClause = $6;
 					$5->returningList = $7;
 					$5->withClause = $1;
+					$5->location = @2;
 					$$ = (Node *) $5;
 				}
 		;
@@ -12323,6 +12324,7 @@ DeleteStmt: opt_with_clause DELETE_P FROM relation_expr_opt_alias
 					n->whereClause = $6;
 					n->returningList = $7;
 					n->withClause = $1;
+					n->location = @2;
 					$$ = (Node *) n;
 				}
 		;
@@ -12397,6 +12399,7 @@ UpdateStmt: opt_with_clause UPDATE relation_expr_opt_alias
 					n->whereClause = $7;
 					n->returningList = $8;
 					n->withClause = $1;
+					n->location = @2;
 					$$ = (Node *) n;
 				}
 		;
@@ -12474,6 +12477,7 @@ MergeStmt:
 					m->joinCondition = $8;
 					m->mergeWhenClauses = $9;
 					m->returningList = $10;
+					m->location = @2;
 
 					$$ = (Node *) m;
 				}
@@ -12836,6 +12840,7 @@ simple_select:
 					n->groupDistinct = ($7)->distinct;
 					n->havingClause = $8;
 					n->windowClause = $9;
+					n->location = @1;
 					$$ = (Node *) n;
 				}
 			| SELECT distinct_clause target_list
@@ -12853,6 +12858,7 @@ simple_select:
 					n->groupDistinct = ($7)->distinct;
 					n->havingClause = $8;
 					n->windowClause = $9;
+					n->location = @1;
 					$$ = (Node *) n;
 				}
 			| values_clause							{ $$ = $1; }
@@ -12877,15 +12883,15 @@ simple_select:
 				}
 			| select_clause UNION set_quantifier select_clause
 				{
-					$$ = makeSetOp(SETOP_UNION, $3 == SET_QUANTIFIER_ALL, $1, $4);
+					$$ = makeSetOp(SETOP_UNION, $3 == SET_QUANTIFIER_ALL, $1, $4, @1);
 				}
 			| select_clause INTERSECT set_quantifier select_clause
 				{
-					$$ = makeSetOp(SETOP_INTERSECT, $3 == SET_QUANTIFIER_ALL, $1, $4);
+					$$ = makeSetOp(SETOP_INTERSECT, $3 == SET_QUANTIFIER_ALL, $1, $4, @1);
 				}
 			| select_clause EXCEPT set_quantifier select_clause
 				{
-					$$ = makeSetOp(SETOP_EXCEPT, $3 == SET_QUANTIFIER_ALL, $1, $4);
+					$$ = makeSetOp(SETOP_EXCEPT, $3 == SET_QUANTIFIER_ALL, $1, $4, @1);
 				}
 		;
 
@@ -18967,7 +18973,7 @@ insertSelectOptions(SelectStmt *stmt,
 }
 
 static Node *
-makeSetOp(SetOperation op, bool all, Node *larg, Node *rarg)
+makeSetOp(SetOperation op, bool all, Node *larg, Node *rarg, int location)
 {
 	SelectStmt *n = makeNode(SelectStmt);
 
@@ -18975,6 +18981,7 @@ makeSetOp(SetOperation op, bool all, Node *larg, Node *rarg)
 	n->all = all;
 	n->larg = (SelectStmt *) larg;
 	n->rarg = (SelectStmt *) rarg;
+	n->location = location;
 	return (Node *) n;
 }
 
diff --git a/src/backend/parser/parse_merge.c b/src/backend/parser/parse_merge.c
index 87df79027d7..2389cc00513 100644
--- a/src/backend/parser/parse_merge.c
+++ b/src/backend/parser/parse_merge.c
@@ -118,6 +118,7 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt)
 	Assert(pstate->p_ctenamespace == NIL);
 
 	qry->commandType = CMD_MERGE;
+	qry->stmt_location = stmt->location;
 	qry->hasRecursive = false;
 
 	/* process the WITH clause independently of all else */
@@ -130,6 +131,7 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt)
 
 		qry->cteList = transformWithClause(pstate, stmt->withClause);
 		qry->hasModifyingCTE = pstate->p_hasModifyingCTE;
+		qry->stmt_location = stmt->withClause->location;
 	}
 
 	/*
diff --git a/src/include/nodes/parsenodes.h b/src/include/nodes/parsenodes.h
index 1c314cd9074..aa73e2b2428 100644
--- a/src/include/nodes/parsenodes.h
+++ b/src/include/nodes/parsenodes.h
@@ -2045,6 +2045,7 @@ typedef struct InsertStmt
 	List	   *returningList;	/* list of expressions to return */
 	WithClause *withClause;		/* WITH clause */
 	OverridingKind override;	/* OVERRIDING clause */
+	ParseLoc	location;		/* token location, or -1 if unknown */
 } InsertStmt;
 
 /* ----------------------
@@ -2059,6 +2060,7 @@ typedef struct DeleteStmt
 	Node	   *whereClause;	/* qualifications */
 	List	   *returningList;	/* list of expressions to return */
 	WithClause *withClause;		/* WITH clause */
+	ParseLoc	location;		/* token location, or -1 if unknown */
 } DeleteStmt;
 
 /* ----------------------
@@ -2074,6 +2076,7 @@ typedef struct UpdateStmt
 	List	   *fromClause;		/* optional from clause for more tables */
 	List	   *returningList;	/* list of expressions to return */
 	WithClause *withClause;		/* WITH clause */
+	ParseLoc	location;		/* token location, or -1 if unknown */
 } UpdateStmt;
 
 /* ----------------------
@@ -2089,6 +2092,7 @@ typedef struct MergeStmt
 	List	   *mergeWhenClauses;	/* list of MergeWhenClause(es) */
 	List	   *returningList;	/* list of expressions to return */
 	WithClause *withClause;		/* WITH clause */
+	ParseLoc	location;		/* token location, or -1 if unknown */
 } MergeStmt;
 
 /* ----------------------
@@ -2158,6 +2162,7 @@ typedef struct SelectStmt
 	bool		all;			/* ALL specified? */
 	struct SelectStmt *larg;	/* left child */
 	struct SelectStmt *rarg;	/* right child */
+	ParseLoc	location;		/* token location, or -1 if unknown */
 	/* Eventually add fields for CORRESPONDING spec here */
 } SelectStmt;
 
-- 
2.39.3 (Apple Git-146)

