From 490e1f8310c1dae00eb7f42b5dfbbfbc7f0bc763 Mon Sep 17 00:00:00 2001
From: Richard Guo <guofenglinux@gmail.com>
Date: Tue, 1 Aug 2023 11:44:01 +0800
Subject: [PATCH v1] Fix reparameterize_path_by_child for SampleScan

---
 src/backend/optimizer/path/allpaths.c        |  8 +--
 src/backend/optimizer/path/costsize.c        | 20 +++----
 src/backend/optimizer/plan/createplan.c      | 18 +++---
 src/backend/optimizer/util/pathnode.c        | 47 +++++++++------
 src/include/nodes/pathnodes.h                | 11 ++++
 src/include/optimizer/cost.h                 |  4 +-
 src/include/optimizer/pathnode.h             |  4 +-
 src/test/regress/expected/partition_join.out | 60 ++++++++++++++++++++
 src/test/regress/sql/partition_join.sql      | 12 ++++
 9 files changed, 139 insertions(+), 45 deletions(-)

diff --git a/src/backend/optimizer/path/allpaths.c b/src/backend/optimizer/path/allpaths.c
index 9bdc70c702..c63f872795 100644
--- a/src/backend/optimizer/path/allpaths.c
+++ b/src/backend/optimizer/path/allpaths.c
@@ -867,7 +867,7 @@ set_tablesample_rel_pathlist(PlannerInfo *root, RelOptInfo *rel, RangeTblEntry *
 	required_outer = rel->lateral_relids;
 
 	/* Consider sampled scan */
-	path = create_samplescan_path(root, rel, required_outer);
+	path = (Path *) create_samplescan_path(root, rel, required_outer);
 
 	/*
 	 * If the sampling method does not support repeatable scans, we must avoid
@@ -4426,9 +4426,6 @@ print_path(PlannerInfo *root, Path *path, int indent)
 				case T_SeqScan:
 					ptype = "SeqScan";
 					break;
-				case T_SampleScan:
-					ptype = "SampleScan";
-					break;
 				case T_FunctionScan:
 					ptype = "FunctionScan";
 					break;
@@ -4455,6 +4452,9 @@ print_path(PlannerInfo *root, Path *path, int indent)
 					break;
 			}
 			break;
+		case T_SampleScanPath:
+			ptype = "SampleScan";
+			break;
 		case T_IndexPath:
 			ptype = "IdxScan";
 			break;
diff --git a/src/backend/optimizer/path/costsize.c b/src/backend/optimizer/path/costsize.c
index ef475d95a1..6c7a5d2627 100644
--- a/src/backend/optimizer/path/costsize.c
+++ b/src/backend/optimizer/path/costsize.c
@@ -330,12 +330,11 @@ cost_seqscan(Path *path, PlannerInfo *root,
  * 'param_info' is the ParamPathInfo if this is a parameterized path, else NULL
  */
 void
-cost_samplescan(Path *path, PlannerInfo *root,
+cost_samplescan(SampleScanPath *path, PlannerInfo *root,
 				RelOptInfo *baserel, ParamPathInfo *param_info)
 {
 	Cost		startup_cost = 0;
 	Cost		run_cost = 0;
-	RangeTblEntry *rte;
 	TableSampleClause *tsc;
 	TsmRoutine *tsm;
 	double		spc_seq_page_cost,
@@ -346,17 +345,16 @@ cost_samplescan(Path *path, PlannerInfo *root,
 
 	/* Should only be applied to base relations with tablesample clauses */
 	Assert(baserel->relid > 0);
-	rte = planner_rt_fetch(baserel->relid, root);
-	Assert(rte->rtekind == RTE_RELATION);
-	tsc = rte->tablesample;
+	Assert(baserel->rtekind == RTE_RELATION);
+	tsc = path->tablesample;
 	Assert(tsc != NULL);
 	tsm = GetTsmRoutine(tsc->tsmhandler);
 
 	/* Mark the path with the correct row estimate */
 	if (param_info)
-		path->rows = param_info->ppi_rows;
+		path->path.rows = param_info->ppi_rows;
 	else
-		path->rows = baserel->rows;
+		path->path.rows = baserel->rows;
 
 	/* fetch estimated page cost for tablespace containing table */
 	get_tablespace_page_costs(baserel->reltablespace,
@@ -387,11 +385,11 @@ cost_samplescan(Path *path, PlannerInfo *root,
 	cpu_per_tuple = cpu_tuple_cost + qpqual_cost.per_tuple;
 	run_cost += cpu_per_tuple * baserel->tuples;
 	/* tlist eval costs are paid per output row, not per tuple scanned */
-	startup_cost += path->pathtarget->cost.startup;
-	run_cost += path->pathtarget->cost.per_tuple * path->rows;
+	startup_cost += path->path.pathtarget->cost.startup;
+	run_cost += path->path.pathtarget->cost.per_tuple * path->path.rows;
 
-	path->startup_cost = startup_cost;
-	path->total_cost = startup_cost + run_cost;
+	path->path.startup_cost = startup_cost;
+	path->path.total_cost = startup_cost + run_cost;
 }
 
 /*
diff --git a/src/backend/optimizer/plan/createplan.c b/src/backend/optimizer/plan/createplan.c
index af48109058..073488b20c 100644
--- a/src/backend/optimizer/plan/createplan.c
+++ b/src/backend/optimizer/plan/createplan.c
@@ -120,7 +120,7 @@ static Limit *create_limit_plan(PlannerInfo *root, LimitPath *best_path,
 								int flags);
 static SeqScan *create_seqscan_plan(PlannerInfo *root, Path *best_path,
 									List *tlist, List *scan_clauses);
-static SampleScan *create_samplescan_plan(PlannerInfo *root, Path *best_path,
+static SampleScan *create_samplescan_plan(PlannerInfo *root, SampleScanPath *best_path,
 										  List *tlist, List *scan_clauses);
 static Scan *create_indexscan_plan(PlannerInfo *root, IndexPath *best_path,
 								   List *tlist, List *scan_clauses, bool indexonly);
@@ -664,7 +664,7 @@ create_scan_plan(PlannerInfo *root, Path *best_path, int flags)
 
 		case T_SampleScan:
 			plan = (Plan *) create_samplescan_plan(root,
-												   best_path,
+												   (SampleScanPath *) best_path,
 												   tlist,
 												   scan_clauses);
 			break;
@@ -2929,19 +2929,17 @@ create_seqscan_plan(PlannerInfo *root, Path *best_path,
  *	 with restriction clauses 'scan_clauses' and targetlist 'tlist'.
  */
 static SampleScan *
-create_samplescan_plan(PlannerInfo *root, Path *best_path,
+create_samplescan_plan(PlannerInfo *root, SampleScanPath *best_path,
 					   List *tlist, List *scan_clauses)
 {
 	SampleScan *scan_plan;
-	Index		scan_relid = best_path->parent->relid;
-	RangeTblEntry *rte;
+	Index		scan_relid = best_path->path.parent->relid;
 	TableSampleClause *tsc;
 
 	/* it should be a base rel with a tablesample clause... */
 	Assert(scan_relid > 0);
-	rte = planner_rt_fetch(scan_relid, root);
-	Assert(rte->rtekind == RTE_RELATION);
-	tsc = rte->tablesample;
+	Assert(best_path->path.parent->rtekind == RTE_RELATION);
+	tsc = best_path->tablesample;
 	Assert(tsc != NULL);
 
 	/* Sort clauses into best execution order */
@@ -2951,7 +2949,7 @@ create_samplescan_plan(PlannerInfo *root, Path *best_path,
 	scan_clauses = extract_actual_clauses(scan_clauses, false);
 
 	/* Replace any outer-relation variables with nestloop params */
-	if (best_path->param_info)
+	if (best_path->path.param_info)
 	{
 		scan_clauses = (List *)
 			replace_nestloop_params(root, (Node *) scan_clauses);
@@ -2964,7 +2962,7 @@ create_samplescan_plan(PlannerInfo *root, Path *best_path,
 								scan_relid,
 								tsc);
 
-	copy_generic_path_info(&scan_plan->scan.plan, best_path);
+	copy_generic_path_info(&scan_plan->scan.plan, &best_path->path);
 
 	return scan_plan;
 }
diff --git a/src/backend/optimizer/util/pathnode.c b/src/backend/optimizer/util/pathnode.c
index f123fcb41e..17b7ac1d8c 100644
--- a/src/backend/optimizer/util/pathnode.c
+++ b/src/backend/optimizer/util/pathnode.c
@@ -950,22 +950,25 @@ create_seqscan_path(PlannerInfo *root, RelOptInfo *rel,
  * create_samplescan_path
  *	  Creates a path node for a sampled table scan.
  */
-Path *
+SampleScanPath *
 create_samplescan_path(PlannerInfo *root, RelOptInfo *rel, Relids required_outer)
 {
-	Path	   *pathnode = makeNode(Path);
+	SampleScanPath	  *pathnode = makeNode(SampleScanPath);
+	RangeTblEntry	  *rte = planner_rt_fetch(rel->relid, root);
 
-	pathnode->pathtype = T_SampleScan;
-	pathnode->parent = rel;
-	pathnode->pathtarget = rel->reltarget;
-	pathnode->param_info = get_baserel_parampathinfo(root, rel,
-													 required_outer);
-	pathnode->parallel_aware = false;
-	pathnode->parallel_safe = rel->consider_parallel;
-	pathnode->parallel_workers = 0;
-	pathnode->pathkeys = NIL;	/* samplescan has unordered result */
+	pathnode->path.pathtype = T_SampleScan;
+	pathnode->path.parent = rel;
+	pathnode->path.pathtarget = rel->reltarget;
+	pathnode->path.param_info = get_baserel_parampathinfo(root, rel,
+														  required_outer);
+	pathnode->path.parallel_aware = false;
+	pathnode->path.parallel_safe = rel->consider_parallel;
+	pathnode->path.parallel_workers = 0;
+	pathnode->path.pathkeys = NIL;	/* samplescan has unordered result */
+
+	pathnode->tablesample = copyObject(rte->tablesample);
 
-	cost_samplescan(pathnode, root, rel, pathnode->param_info);
+	cost_samplescan(pathnode, root, rel, pathnode->path.param_info);
 
 	return pathnode;
 }
@@ -4047,11 +4050,13 @@ reparameterize_path_by_child(PlannerInfo *root, Path *path,
 	( (newnode) = makeNode(nodetype), \
 	  memcpy((newnode), (node), sizeof(nodetype)) )
 
-#define ADJUST_CHILD_ATTRS(node) \
+#define ADJUST_CHILD_EXPRS(node, fieldtype) \
 	((node) = \
-	 (List *) adjust_appendrel_attrs_multilevel(root, (Node *) (node), \
-												child_rel, \
-												child_rel->top_parent))
+	 (fieldtype) adjust_appendrel_attrs_multilevel(root, (Node *) (node), \
+												   child_rel, \
+												   child_rel->top_parent))
+
+#define ADJUST_CHILD_ATTRS(node) ADJUST_CHILD_EXPRS(node, List *)
 
 #define REPARAMETERIZE_CHILD_PATH(path) \
 do { \
@@ -4102,6 +4107,16 @@ do { \
 			FLAT_COPY_PATH(new_path, path, Path);
 			break;
 
+		case T_SampleScanPath:
+			{
+				SampleScanPath  *sspath;
+
+				FLAT_COPY_PATH(sspath, path, SampleScanPath);
+				ADJUST_CHILD_EXPRS(sspath->tablesample, TableSampleClause *);
+				new_path = (Path *) sspath;
+			}
+			break;
+
 		case T_IndexPath:
 			{
 				IndexPath  *ipath;
diff --git a/src/include/nodes/pathnodes.h b/src/include/nodes/pathnodes.h
index c17b53f7ad..2b54138e76 100644
--- a/src/include/nodes/pathnodes.h
+++ b/src/include/nodes/pathnodes.h
@@ -1637,6 +1637,17 @@ typedef struct Path
 #define PATH_REQ_OUTER(path)  \
 	((path)->param_info ? (path)->param_info->ppi_req_outer : (Relids) NULL)
 
+/*
+ * SampleScanPath represents a sample scan of a table.
+ *
+ * tablesample is the sampling info.
+ */
+typedef struct SampleScanPath
+{
+	Path		path;
+	TableSampleClause *tablesample;
+} SampleScanPath;
+
 /*----------
  * IndexPath represents an index scan over a single index.
  *
diff --git a/src/include/optimizer/cost.h b/src/include/optimizer/cost.h
index 6cf49705d3..4c8bdb630b 100644
--- a/src/include/optimizer/cost.h
+++ b/src/include/optimizer/cost.h
@@ -76,8 +76,8 @@ extern double index_pages_fetched(double tuples_fetched, BlockNumber pages,
 								  double index_pages, PlannerInfo *root);
 extern void cost_seqscan(Path *path, PlannerInfo *root, RelOptInfo *baserel,
 						 ParamPathInfo *param_info);
-extern void cost_samplescan(Path *path, PlannerInfo *root, RelOptInfo *baserel,
-							ParamPathInfo *param_info);
+extern void cost_samplescan(SampleScanPath *path, PlannerInfo *root,
+							RelOptInfo *baserel, ParamPathInfo *param_info);
 extern void cost_index(IndexPath *path, PlannerInfo *root,
 					   double loop_count, bool partial_path);
 extern void cost_bitmap_heap_scan(Path *path, PlannerInfo *root, RelOptInfo *baserel,
diff --git a/src/include/optimizer/pathnode.h b/src/include/optimizer/pathnode.h
index 001e75b5b7..b93a195589 100644
--- a/src/include/optimizer/pathnode.h
+++ b/src/include/optimizer/pathnode.h
@@ -36,8 +36,8 @@ extern bool add_partial_path_precheck(RelOptInfo *parent_rel,
 
 extern Path *create_seqscan_path(PlannerInfo *root, RelOptInfo *rel,
 								 Relids required_outer, int parallel_workers);
-extern Path *create_samplescan_path(PlannerInfo *root, RelOptInfo *rel,
-									Relids required_outer);
+extern SampleScanPath *create_samplescan_path(PlannerInfo *root, RelOptInfo *rel,
+											  Relids required_outer);
 extern IndexPath *create_index_path(PlannerInfo *root,
 									IndexOptInfo *index,
 									List *indexclauses,
diff --git a/src/test/regress/expected/partition_join.out b/src/test/regress/expected/partition_join.out
index 6560fe2416..a11f738411 100644
--- a/src/test/regress/expected/partition_join.out
+++ b/src/test/regress/expected/partition_join.out
@@ -505,6 +505,31 @@ SELECT t1.a, ss.t2a, ss.t2c FROM prt1 t1 LEFT JOIN LATERAL
  550 |     | 
 (12 rows)
 
+-- lateral reference in sample scan
+EXPLAIN (COSTS OFF)
+SELECT * FROM prt1 t1 JOIN LATERAL
+			  (SELECT * FROM prt1 t2 TABLESAMPLE SYSTEM (t1.a) REPEATABLE(t1.b)) s
+			  ON t1.a = s.a;
+                         QUERY PLAN                          
+-------------------------------------------------------------
+ Append
+   ->  Nested Loop
+         ->  Seq Scan on prt1_p1 t1_1
+         ->  Sample Scan on prt1_p1 t2_1
+               Sampling: system (t1_1.a) REPEATABLE (t1_1.b)
+               Filter: (t1_1.a = a)
+   ->  Nested Loop
+         ->  Seq Scan on prt1_p2 t1_2
+         ->  Sample Scan on prt1_p2 t2_2
+               Sampling: system (t1_2.a) REPEATABLE (t1_2.b)
+               Filter: (t1_2.a = a)
+   ->  Nested Loop
+         ->  Seq Scan on prt1_p3 t1_3
+         ->  Sample Scan on prt1_p3 t2_3
+               Sampling: system (t1_3.a) REPEATABLE (t1_3.b)
+               Filter: (t1_3.a = a)
+(16 rows)
+
 -- bug with inadequate sort key representation
 SET enable_partitionwise_aggregate TO true;
 SET enable_hashjoin TO false;
@@ -1944,6 +1969,41 @@ SELECT * FROM prt1_l t1 LEFT JOIN LATERAL
  550 | 0 | 0002 |     |      |     |     |      
 (12 rows)
 
+-- partitionwise join with lateral reference in sample scan
+EXPLAIN (COSTS OFF)
+SELECT * FROM prt1_l t1 JOIN LATERAL
+			  (SELECT * FROM prt1_l t2 TABLESAMPLE SYSTEM (t1.a) REPEATABLE(t1.b)) s ON
+			  t1.a = s.a AND t1.b = s.b AND t1.c = s.c;
+                                       QUERY PLAN                                       
+----------------------------------------------------------------------------------------
+ Append
+   ->  Nested Loop
+         ->  Seq Scan on prt1_l_p1 t1_1
+         ->  Sample Scan on prt1_l_p1 t2_1
+               Sampling: system (t1_1.a) REPEATABLE (t1_1.b)
+               Filter: ((t1_1.a = a) AND (t1_1.b = b) AND ((t1_1.c)::text = (c)::text))
+   ->  Nested Loop
+         ->  Seq Scan on prt1_l_p2_p1 t1_2
+         ->  Sample Scan on prt1_l_p2_p1 t2_2
+               Sampling: system (t1_2.a) REPEATABLE (t1_2.b)
+               Filter: ((t1_2.a = a) AND (t1_2.b = b) AND ((t1_2.c)::text = (c)::text))
+   ->  Nested Loop
+         ->  Seq Scan on prt1_l_p2_p2 t1_3
+         ->  Sample Scan on prt1_l_p2_p2 t2_3
+               Sampling: system (t1_3.a) REPEATABLE (t1_3.b)
+               Filter: ((t1_3.a = a) AND (t1_3.b = b) AND ((t1_3.c)::text = (c)::text))
+   ->  Nested Loop
+         ->  Seq Scan on prt1_l_p3_p1 t1_4
+         ->  Sample Scan on prt1_l_p3_p1 t2_4
+               Sampling: system (t1_4.a) REPEATABLE (t1_4.b)
+               Filter: ((t1_4.a = a) AND (t1_4.b = b) AND ((t1_4.c)::text = (c)::text))
+   ->  Nested Loop
+         ->  Seq Scan on prt1_l_p3_p2 t1_5
+         ->  Sample Scan on prt1_l_p3_p2 t2_5
+               Sampling: system (t1_5.a) REPEATABLE (t1_5.b)
+               Filter: ((t1_5.a = a) AND (t1_5.b = b) AND ((t1_5.c)::text = (c)::text))
+(26 rows)
+
 -- join with one side empty
 EXPLAIN (COSTS OFF)
 SELECT t1.a, t1.c, t2.b, t2.c FROM (SELECT * FROM prt1_l WHERE a = 1 AND a = 2) t1 RIGHT JOIN prt2_l t2 ON t1.a = t2.b AND t1.b = t2.a AND t1.c = t2.c;
diff --git a/src/test/regress/sql/partition_join.sql b/src/test/regress/sql/partition_join.sql
index 48daf3aee3..e2daab03fb 100644
--- a/src/test/regress/sql/partition_join.sql
+++ b/src/test/regress/sql/partition_join.sql
@@ -100,6 +100,12 @@ SELECT t1.a, ss.t2a, ss.t2c FROM prt1 t1 LEFT JOIN LATERAL
 			  (SELECT t2.a AS t2a, t3.a AS t3a, t2.b t2b, t2.c t2c, least(t1.a,t2.a,t3.a) FROM prt1 t2 JOIN prt2 t3 ON (t2.a = t3.b)) ss
 			  ON t1.c = ss.t2c WHERE (t1.b + coalesce(ss.t2b, 0)) = 0 ORDER BY t1.a;
 
+-- lateral reference in sample scan
+EXPLAIN (COSTS OFF)
+SELECT * FROM prt1 t1 JOIN LATERAL
+			  (SELECT * FROM prt1 t2 TABLESAMPLE SYSTEM (t1.a) REPEATABLE(t1.b)) s
+			  ON t1.a = s.a;
+
 -- bug with inadequate sort key representation
 SET enable_partitionwise_aggregate TO true;
 SET enable_hashjoin TO false;
@@ -387,6 +393,12 @@ SELECT * FROM prt1_l t1 LEFT JOIN LATERAL
 			  (SELECT t2.a AS t2a, t2.c AS t2c, t2.b AS t2b, t3.b AS t3b, least(t1.a,t2.a,t3.b) FROM prt1_l t2 JOIN prt2_l t3 ON (t2.a = t3.b AND t2.c = t3.c)) ss
 			  ON t1.a = ss.t2a AND t1.c = ss.t2c WHERE t1.b = 0 ORDER BY t1.a;
 
+-- partitionwise join with lateral reference in sample scan
+EXPLAIN (COSTS OFF)
+SELECT * FROM prt1_l t1 JOIN LATERAL
+			  (SELECT * FROM prt1_l t2 TABLESAMPLE SYSTEM (t1.a) REPEATABLE(t1.b)) s ON
+			  t1.a = s.a AND t1.b = s.b AND t1.c = s.c;
+
 -- join with one side empty
 EXPLAIN (COSTS OFF)
 SELECT t1.a, t1.c, t2.b, t2.c FROM (SELECT * FROM prt1_l WHERE a = 1 AND a = 2) t1 RIGHT JOIN prt2_l t2 ON t1.a = t2.b AND t1.b = t2.a AND t1.c = t2.c;
-- 
2.31.0

