From 0f5f23370fcbfa25b6705bcf4a4667e51c2dbaf9 Mon Sep 17 00:00:00 2001
From: Richard Guo <guofenglinux@gmail.com>
Date: Thu, 16 May 2024 06:17:37 +0000
Subject: [PATCH v4] Introduce a RTE for the grouping step

---
 .../postgres_fdw/expected/postgres_fdw.out    |   2 +-
 src/backend/commands/explain.c                |  21 +-
 src/backend/nodes/nodeFuncs.c                 |  14 ++
 src/backend/nodes/outfuncs.c                  |   3 +
 src/backend/nodes/print.c                     |   4 +
 src/backend/nodes/readfuncs.c                 |   3 +
 src/backend/optimizer/path/allpaths.c         |   4 +
 src/backend/optimizer/path/equivclass.c       |  12 +
 src/backend/optimizer/plan/initsplan.c        |   4 +
 src/backend/optimizer/plan/planner.c          |  31 ++-
 src/backend/optimizer/plan/setrefs.c          |   1 +
 src/backend/optimizer/prep/prepjointree.c     |   9 +-
 src/backend/optimizer/util/var.c              | 125 ++++++++++
 src/backend/parser/parse_agg.c                | 214 +++++++++++++++++-
 src/backend/parser/parse_clause.c             |   4 +-
 src/backend/parser/parse_relation.c           |  79 ++++++-
 src/backend/parser/parse_target.c             |   2 +
 src/backend/utils/adt/ruleutils.c             |  19 +-
 src/include/commands/explain.h                |   1 +
 src/include/nodes/nodeFuncs.h                 |   2 +
 src/include/nodes/parsenodes.h                |   7 +
 src/include/nodes/pathnodes.h                 |   5 +
 src/include/optimizer/optimizer.h             |   1 +
 src/include/parser/parse_clause.h             |   2 +
 src/include/parser/parse_node.h               |   2 +
 src/include/parser/parse_relation.h           |   2 +
 src/test/regress/expected/groupingsets.out    |  49 ++++
 src/test/regress/sql/groupingsets.sql         |  23 ++
 28 files changed, 624 insertions(+), 21 deletions(-)

diff --git a/contrib/postgres_fdw/expected/postgres_fdw.out b/contrib/postgres_fdw/expected/postgres_fdw.out
index 078b8a966f..edc8f1d51b 100644
--- a/contrib/postgres_fdw/expected/postgres_fdw.out
+++ b/contrib/postgres_fdw/expected/postgres_fdw.out
@@ -3669,7 +3669,7 @@ select count(*), sum(t1.c1), avg(t2.c1) from (select c1 from ft4 where c1 betwee
  Foreign Scan
    Output: (count(*)), (sum(ft4.c1)), (avg(ft5.c1))
    Relations: Aggregate on ((public.ft4) FULL JOIN (public.ft5))
-   Remote SQL: SELECT count(*), sum(s4.c1), avg(s5.c1) FROM ((SELECT c1 FROM "S 1"."T 3" WHERE ((c1 >= 50)) AND ((c1 <= 60))) s4(c1) FULL JOIN (SELECT c1 FROM "S 1"."T 4" WHERE ((c1 >= 50)) AND ((c1 <= 60))) s5(c1) ON (((s4.c1 = s5.c1))))
+   Remote SQL: SELECT count(*), sum(s5.c1), avg(s6.c1) FROM ((SELECT c1 FROM "S 1"."T 3" WHERE ((c1 >= 50)) AND ((c1 <= 60))) s5(c1) FULL JOIN (SELECT c1 FROM "S 1"."T 4" WHERE ((c1 >= 50)) AND ((c1 <= 60))) s6(c1) ON (((s5.c1 = s6.c1))))
 (4 rows)
 
 select count(*), sum(t1.c1), avg(t2.c1) from (select c1 from ft4 where c1 between 50 and 60) t1 full join (select c1 from ft5 where c1 between 50 and 60) t2 on (t1.c1 = t2.c1);
diff --git a/src/backend/commands/explain.c b/src/backend/commands/explain.c
index 94511a5a02..6840c3d596 100644
--- a/src/backend/commands/explain.c
+++ b/src/backend/commands/explain.c
@@ -877,6 +877,7 @@ ExplainPrintPlan(ExplainState *es, QueryDesc *queryDesc)
 {
 	Bitmapset  *rels_used = NULL;
 	PlanState  *ps;
+	ListCell   *lc;
 
 	/* Set up ExplainState fields associated with this plan tree */
 	Assert(queryDesc->plannedstmt != NULL);
@@ -887,6 +888,14 @@ ExplainPrintPlan(ExplainState *es, QueryDesc *queryDesc)
 	es->deparse_cxt = deparse_context_for_plan_tree(queryDesc->plannedstmt,
 													es->rtable_names);
 	es->printed_subplans = NULL;
+	es->rtable_size = list_length(es->rtable);
+	foreach (lc, es->rtable)
+	{
+		RangeTblEntry *rte = lfirst_node(RangeTblEntry, lc);
+
+		if (rte->rtekind == RTE_GROUP)
+			es->rtable_size--;
+	}
 
 	/*
 	 * Sometimes we mark a Gather node as "invisible", which means that it's
@@ -2463,7 +2472,7 @@ show_plan_tlist(PlanState *planstate, List *ancestors, ExplainState *es)
 	context = set_deparse_context_plan(es->deparse_cxt,
 									   plan,
 									   ancestors);
-	useprefix = list_length(es->rtable) > 1;
+	useprefix = es->rtable_size > 1;
 
 	/* Deparse each result column (we now include resjunk ones) */
 	foreach(lc, plan->targetlist)
@@ -2547,7 +2556,7 @@ show_upper_qual(List *qual, const char *qlabel,
 {
 	bool		useprefix;
 
-	useprefix = (list_length(es->rtable) > 1 || es->verbose);
+	useprefix = (es->rtable_size > 1 || es->verbose);
 	show_qual(qual, qlabel, planstate, ancestors, useprefix, es);
 }
 
@@ -2637,7 +2646,7 @@ show_grouping_sets(PlanState *planstate, Agg *agg,
 	context = set_deparse_context_plan(es->deparse_cxt,
 									   planstate->plan,
 									   ancestors);
-	useprefix = (list_length(es->rtable) > 1 || es->verbose);
+	useprefix = (es->rtable_size > 1 || es->verbose);
 
 	ExplainOpenGroup("Grouping Sets", "Grouping Sets", false, es);
 
@@ -2777,7 +2786,7 @@ show_sort_group_keys(PlanState *planstate, const char *qlabel,
 	context = set_deparse_context_plan(es->deparse_cxt,
 									   plan,
 									   ancestors);
-	useprefix = (list_length(es->rtable) > 1 || es->verbose);
+	useprefix = (es->rtable_size > 1 || es->verbose);
 
 	for (keyno = 0; keyno < nkeys; keyno++)
 	{
@@ -2889,7 +2898,7 @@ show_tablesample(TableSampleClause *tsc, PlanState *planstate,
 	context = set_deparse_context_plan(es->deparse_cxt,
 									   planstate->plan,
 									   ancestors);
-	useprefix = list_length(es->rtable) > 1;
+	useprefix = es->rtable_size > 1;
 
 	/* Get the tablesample method name */
 	method_name = get_func_name(tsc->tsmhandler);
@@ -3339,7 +3348,7 @@ show_memoize_info(MemoizeState *mstate, List *ancestors, ExplainState *es)
 	 * It's hard to imagine having a memoize node with fewer than 2 RTEs, but
 	 * let's just keep the same useprefix logic as elsewhere in this file.
 	 */
-	useprefix = list_length(es->rtable) > 1 || es->verbose;
+	useprefix = es->rtable_size > 1 || es->verbose;
 
 	/* Set up deparsing context */
 	context = set_deparse_context_plan(es->deparse_cxt,
diff --git a/src/backend/nodes/nodeFuncs.c b/src/backend/nodes/nodeFuncs.c
index 89ee4b61f2..6f0f8e8c54 100644
--- a/src/backend/nodes/nodeFuncs.c
+++ b/src/backend/nodes/nodeFuncs.c
@@ -2862,6 +2862,11 @@ range_table_entry_walker_impl(RangeTblEntry *rte,
 		case RTE_RESULT:
 			/* nothing to do */
 			break;
+		case RTE_GROUP:
+			if (!(flags & QTW_IGNORE_GROUPEXPRS))
+				if (WALK(rte->groupexprs))
+					return true;
+			break;
 	}
 
 	if (WALK(rte->securityQuals))
@@ -3900,6 +3905,15 @@ range_table_mutator_impl(List *rtable,
 			case RTE_RESULT:
 				/* nothing to do */
 				break;
+			case RTE_GROUP:
+				if (!(flags & QTW_IGNORE_GROUPEXPRS))
+					MUTATE(newrte->groupexprs, rte->groupexprs, List *);
+				else
+				{
+					/* else, copy group exprs as-is */
+					newrte->groupexprs = copyObject(rte->groupexprs);
+				}
+				break;
 		}
 		MUTATE(newrte->securityQuals, rte->securityQuals, List *);
 		newrt = lappend(newrt, newrte);
diff --git a/src/backend/nodes/outfuncs.c b/src/backend/nodes/outfuncs.c
index 3337b77ae6..9827cf16be 100644
--- a/src/backend/nodes/outfuncs.c
+++ b/src/backend/nodes/outfuncs.c
@@ -562,6 +562,9 @@ _outRangeTblEntry(StringInfo str, const RangeTblEntry *node)
 		case RTE_RESULT:
 			/* no extra fields */
 			break;
+		case RTE_GROUP:
+			WRITE_NODE_FIELD(groupexprs);
+			break;
 		default:
 			elog(ERROR, "unrecognized RTE kind: %d", (int) node->rtekind);
 			break;
diff --git a/src/backend/nodes/print.c b/src/backend/nodes/print.c
index 02798f4482..03416e8f4a 100644
--- a/src/backend/nodes/print.c
+++ b/src/backend/nodes/print.c
@@ -300,6 +300,10 @@ print_rt(const List *rtable)
 				printf("%d\t%s\t[result]",
 					   i, rte->eref->aliasname);
 				break;
+			case RTE_GROUP:
+				printf("%d\t%s\t[group]",
+					   i, rte->eref->aliasname);
+				break;
 			default:
 				printf("%d\t%s\t[unknown rtekind]",
 					   i, rte->eref->aliasname);
diff --git a/src/backend/nodes/readfuncs.c b/src/backend/nodes/readfuncs.c
index c4d01a441a..818e472a3b 100644
--- a/src/backend/nodes/readfuncs.c
+++ b/src/backend/nodes/readfuncs.c
@@ -422,6 +422,9 @@ _readRangeTblEntry(void)
 		case RTE_RESULT:
 			/* no extra fields */
 			break;
+		case RTE_GROUP:
+			READ_NODE_FIELD(groupexprs);
+			break;
 		default:
 			elog(ERROR, "unrecognized RTE kind: %d",
 				 (int) local_node->rtekind);
diff --git a/src/backend/optimizer/path/allpaths.c b/src/backend/optimizer/path/allpaths.c
index 4895cee994..2ee478195f 100644
--- a/src/backend/optimizer/path/allpaths.c
+++ b/src/backend/optimizer/path/allpaths.c
@@ -731,6 +731,10 @@ set_rel_consider_parallel(PlannerInfo *root, RelOptInfo *rel,
 		case RTE_RESULT:
 			/* RESULT RTEs, in themselves, are no problem. */
 			break;
+		case RTE_GROUP:
+			/* Shouldn't happen; we're only considering baserels here. */
+			Assert(false);
+			return;
 	}
 
 	/*
diff --git a/src/backend/optimizer/path/equivclass.c b/src/backend/optimizer/path/equivclass.c
index 21ce1ae2e1..61c450bb99 100644
--- a/src/backend/optimizer/path/equivclass.c
+++ b/src/backend/optimizer/path/equivclass.c
@@ -737,6 +737,10 @@ get_eclass_for_sort_expr(PlannerInfo *root,
 		{
 			RelOptInfo *rel = root->simple_rel_array[i];
 
+			/* ignore GROUP RTE */
+			if (i == root->group_rtindex)
+				continue;
+
 			if (rel == NULL)	/* must be an outer join */
 			{
 				Assert(bms_is_member(i, root->outer_join_rels));
@@ -1098,6 +1102,10 @@ generate_base_implied_equalities(PlannerInfo *root)
 		{
 			RelOptInfo *rel = root->simple_rel_array[i];
 
+			/* ignore GROUP RTE */
+			if (i == root->group_rtindex)
+				continue;
+
 			if (rel == NULL)	/* must be an outer join */
 			{
 				Assert(bms_is_member(i, root->outer_join_rels));
@@ -3353,6 +3361,10 @@ get_eclass_indexes_for_relids(PlannerInfo *root, Relids relids)
 	{
 		RelOptInfo *rel = root->simple_rel_array[i];
 
+		/* ignore GROUP RTE */
+		if (i == root->group_rtindex)
+			continue;
+
 		if (rel == NULL)		/* must be an outer join */
 		{
 			Assert(bms_is_member(i, root->outer_join_rels));
diff --git a/src/backend/optimizer/plan/initsplan.c b/src/backend/optimizer/plan/initsplan.c
index e2c68fe6f9..48fad35051 100644
--- a/src/backend/optimizer/plan/initsplan.c
+++ b/src/backend/optimizer/plan/initsplan.c
@@ -1328,6 +1328,10 @@ mark_rels_nulled_by_join(PlannerInfo *root, Index ojrelid,
 	{
 		RelOptInfo *rel = root->simple_rel_array[relid];
 
+		/* ignore GROUP RTE */
+		if (relid == root->group_rtindex)
+			continue;
+
 		if (rel == NULL)		/* must be an outer join */
 		{
 			Assert(bms_is_member(relid, root->outer_join_rels));
diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c
index 032818423f..b969aa3bcf 100644
--- a/src/backend/optimizer/plan/planner.c
+++ b/src/backend/optimizer/plan/planner.c
@@ -748,6 +748,7 @@ subquery_planner(PlannerGlobal *glob, Query *parse, PlannerInfo *parent_root,
 	 */
 	root->hasJoinRTEs = false;
 	root->hasLateralRTEs = false;
+	root->group_rtindex = 0;
 	hasOuterJoins = false;
 	hasResultRTEs = false;
 	foreach(l, parse->rtable)
@@ -781,6 +782,9 @@ subquery_planner(PlannerGlobal *glob, Query *parse, PlannerInfo *parent_root,
 			case RTE_RESULT:
 				hasResultRTEs = true;
 				break;
+			case RTE_GROUP:
+				root->group_rtindex = list_cell_number(parse->rtable, l) + 1;
+				break;
 			default:
 				/* No work here for other RTE types */
 				break;
@@ -836,10 +840,6 @@ subquery_planner(PlannerGlobal *glob, Query *parse, PlannerInfo *parent_root,
 		preprocess_expression(root, (Node *) parse->targetList,
 							  EXPRKIND_TARGET);
 
-	/* Constant-folding might have removed all set-returning functions */
-	if (parse->hasTargetSRFs)
-		parse->hasTargetSRFs = expression_returns_set((Node *) parse->targetList);
-
 	newWithCheckOptions = NIL;
 	foreach(l, parse->withCheckOptions)
 	{
@@ -969,6 +969,13 @@ subquery_planner(PlannerGlobal *glob, Query *parse, PlannerInfo *parent_root,
 			rte->values_lists = (List *)
 				preprocess_expression(root, (Node *) rte->values_lists, kind);
 		}
+		else if (rte->rtekind == RTE_GROUP)
+		{
+			/* Preprocess the groupexprs lists fully */
+			rte->groupexprs = (List *)
+				preprocess_expression(root, (Node *) rte->groupexprs,
+									  EXPRKIND_TARGET);
+		}
 
 		/*
 		 * Process each element of the securityQuals list as if it were a
@@ -984,6 +991,22 @@ subquery_planner(PlannerGlobal *glob, Query *parse, PlannerInfo *parent_root,
 		}
 	}
 
+	/*
+	 * Replace any Vars that reference GROUP outputs in the subquery's
+	 * targetlist and havingQual with the underlying grouping expressions.
+	 */
+	if (root->group_rtindex > 0)
+	{
+		parse->targetList = (List *)
+			flatten_group_exprs(root, root->parse, (Node *) parse->targetList);
+		parse->havingQual =
+			flatten_group_exprs(root, root->parse, parse->havingQual);
+	}
+
+	/* Constant-folding might have removed all set-returning functions */
+	if (parse->hasTargetSRFs)
+		parse->hasTargetSRFs = expression_returns_set((Node *) parse->targetList);
+
 	/*
 	 * Now that we are done preprocessing expressions, and in particular done
 	 * flattening join alias variables, get rid of the joinaliasvars lists.
diff --git a/src/backend/optimizer/plan/setrefs.c b/src/backend/optimizer/plan/setrefs.c
index 37abcb4701..631d4d2c70 100644
--- a/src/backend/optimizer/plan/setrefs.c
+++ b/src/backend/optimizer/plan/setrefs.c
@@ -557,6 +557,7 @@ add_rte_to_flat_rtable(PlannerGlobal *glob, List *rteperminfos,
 	newrte->coltypes = NIL;
 	newrte->coltypmods = NIL;
 	newrte->colcollations = NIL;
+	newrte->groupexprs = NIL;
 	newrte->securityQuals = NIL;
 
 	glob->finalrtable = lappend(glob->finalrtable, newrte);
diff --git a/src/backend/optimizer/prep/prepjointree.c b/src/backend/optimizer/prep/prepjointree.c
index 5482ab85a7..728c07f464 100644
--- a/src/backend/optimizer/prep/prepjointree.c
+++ b/src/backend/optimizer/prep/prepjointree.c
@@ -1235,6 +1235,7 @@ pull_up_simple_subquery(PlannerInfo *root, Node *jtnode, RangeTblEntry *rte,
 				case RTE_CTE:
 				case RTE_NAMEDTUPLESTORE:
 				case RTE_RESULT:
+				case RTE_GROUP:
 					/* these can't contain any lateral references */
 					break;
 			}
@@ -2218,7 +2219,8 @@ perform_pullup_replace_vars(PlannerInfo *root,
 	}
 
 	/*
-	 * Replace references in the joinaliasvars lists of join RTEs.
+	 * Replace references in the joinaliasvars lists of join RTEs and the
+	 * groupexprs list of group RTE.
 	 */
 	foreach(lc, parse->rtable)
 	{
@@ -2228,6 +2230,10 @@ perform_pullup_replace_vars(PlannerInfo *root,
 			otherrte->joinaliasvars = (List *)
 				pullup_replace_vars((Node *) otherrte->joinaliasvars,
 									rvcontext);
+		else if (otherrte->rtekind == RTE_GROUP)
+			otherrte->groupexprs = (List *)
+				pullup_replace_vars((Node *) otherrte->groupexprs,
+									rvcontext);
 	}
 }
 
@@ -2293,6 +2299,7 @@ replace_vars_in_jointree(Node *jtnode,
 					case RTE_CTE:
 					case RTE_NAMEDTUPLESTORE:
 					case RTE_RESULT:
+					case RTE_GROUP:
 						/* these shouldn't be marked LATERAL */
 						Assert(false);
 						break;
diff --git a/src/backend/optimizer/util/var.c b/src/backend/optimizer/util/var.c
index 844fc30978..fa7860bec7 100644
--- a/src/backend/optimizer/util/var.c
+++ b/src/backend/optimizer/util/var.c
@@ -81,6 +81,8 @@ static bool pull_var_clause_walker(Node *node,
 								   pull_var_clause_context *context);
 static Node *flatten_join_alias_vars_mutator(Node *node,
 											 flatten_join_alias_vars_context *context);
+static Node *flatten_group_exprs_mutator(Node *node,
+										 flatten_join_alias_vars_context *context);
 static Node *add_nullingrels_if_needed(PlannerInfo *root, Node *newnode,
 									   Var *oldvar);
 static bool is_standard_join_alias_expression(Node *newnode, Var *oldvar);
@@ -902,6 +904,129 @@ flatten_join_alias_vars_mutator(Node *node,
 								   (void *) context);
 }
 
+/*
+ * flatten_group_exprs
+ *	  Replace Vars that reference GROUP outputs with references to the original
+ *	  relation variables instead.
+ */
+Node *
+flatten_group_exprs(PlannerInfo *root, Query *query, Node *node)
+{
+	flatten_join_alias_vars_context context;
+
+	/*
+	 * We do not expect this to be applied to the whole Query, only to
+	 * expressions or LATERAL subqueries.  Hence, if the top node is a Query,
+	 * it's okay to immediately increment sublevels_up.
+	 */
+	Assert(node != (Node *) query);
+
+	context.root = root;
+	context.query = query;
+	context.sublevels_up = 0;
+	/* flag whether join aliases could possibly contain SubLinks */
+	context.possible_sublink = query->hasSubLinks;
+	/* if hasSubLinks is already true, no need to work hard */
+	context.inserted_sublink = query->hasSubLinks;
+
+	return flatten_group_exprs_mutator(node, &context);
+}
+
+static Node *
+flatten_group_exprs_mutator(Node *node,
+							flatten_join_alias_vars_context *context)
+{
+	if (node == NULL)
+		return NULL;
+	if (IsA(node, Var))
+	{
+		Var		   *var = (Var *) node;
+		RangeTblEntry *rte;
+		Node	   *newvar;
+
+		/* No change unless Var belongs to the GROUP of the target level */
+		if (var->varlevelsup != context->sublevels_up)
+			return node;		/* no need to copy, really */
+		rte = rt_fetch(var->varno, context->query->rtable);
+		if (rte->rtekind != RTE_GROUP)
+			return node;
+
+		/* Expand group exprs reference */
+		Assert(var->varattno > 0);
+		newvar = (Node *) list_nth(rte->groupexprs, var->varattno - 1);
+		Assert(newvar != NULL);
+		newvar = copyObject(newvar);
+
+		/*
+		 * If we are expanding an expr carried down from an upper query, must
+		 * adjust its varlevelsup fields.
+		 */
+		if (context->sublevels_up != 0)
+			IncrementVarSublevelsUp(newvar, context->sublevels_up, 0);
+
+		/* Preserve original Var's location, if possible */
+		if (IsA(newvar, Var))
+			((Var *) newvar)->location = var->location;
+
+		/* Detect if we are adding a sublink to query */
+		if (context->possible_sublink && !context->inserted_sublink)
+			context->inserted_sublink = checkExprHasSubLink(newvar);
+
+		/*
+		 * TODO var->varnullingrels might have the nullingrel bit that
+		 * references RTE_GROUP.  We're supposed to add it to the replacement
+		 * expression.
+		 *
+		 * Maybe we can do something like add_nullingrels_if_needed().
+		 */
+		return newvar;
+	}
+
+	if (IsA(node, Aggref))
+	{
+		Aggref	   *agg = (Aggref *) node;
+
+		if ((int) agg->agglevelsup > context->sublevels_up)
+			return node;
+
+		agg = copyObject(agg);
+		agg->aggdirectargs = (List *)
+			flatten_group_exprs_mutator((Node *) agg->aggdirectargs, context);
+
+		return (Node *) agg;
+	}
+
+	if (IsA(node, GroupingFunc))
+	{
+		GroupingFunc *grp = (GroupingFunc *) node;
+
+		if ((int) grp->agglevelsup >= context->sublevels_up)
+			return node;
+	}
+
+	if (IsA(node, Query))
+	{
+		/* Recurse into RTE subquery or not-yet-planned sublink subquery */
+		Query	   *newnode;
+		bool		save_inserted_sublink;
+
+		context->sublevels_up++;
+		save_inserted_sublink = context->inserted_sublink;
+		context->inserted_sublink = ((Query *) node)->hasSubLinks;
+		newnode = query_tree_mutator((Query *) node,
+									 flatten_group_exprs_mutator,
+									 (void *) context,
+									 QTW_IGNORE_GROUPEXPRS);
+		newnode->hasSubLinks |= context->inserted_sublink;
+		context->inserted_sublink = save_inserted_sublink;
+		context->sublevels_up--;
+		return (Node *) newnode;
+	}
+
+	return expression_tree_mutator(node, flatten_group_exprs_mutator,
+								   (void *) context);
+}
+
 /*
  * Add oldvar's varnullingrels, if any, to a flattened join alias expression.
  * The newnode has been copied, so we can modify it freely.
diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c
index bee7d8346a..7e2ec2ef4a 100644
--- a/src/backend/parser/parse_agg.c
+++ b/src/backend/parser/parse_agg.c
@@ -26,6 +26,7 @@
 #include "parser/parse_clause.h"
 #include "parser/parse_coerce.h"
 #include "parser/parse_expr.h"
+#include "parser/parse_relation.h"
 #include "parser/parsetree.h"
 #include "rewrite/rewriteManip.h"
 #include "utils/builtins.h"
@@ -53,6 +54,15 @@ typedef struct
 	bool		in_agg_direct_args;
 } check_ungrouped_columns_context;
 
+typedef struct
+{
+	ParseState *pstate;
+	List	   *groupClauses;
+	List	   *groupClauseCommonExprs;
+	bool		have_non_var_grouping;
+	int			sublevels_up;
+} substitute_group_exprs_context;
+
 static int	check_agg_arguments(ParseState *pstate,
 								List *directargs,
 								List *args,
@@ -65,6 +75,11 @@ static void check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry,
 									List **func_grouped_rels);
 static bool check_ungrouped_columns_walker(Node *node,
 										   check_ungrouped_columns_context *context);
+static Node *substitute_group_exprs(Node *node, ParseState *pstate,
+									List *groupClauses, List *groupClauseCommonExprs,
+									bool have_non_var_grouping);
+static Node *substitute_group_exprs_mutator(Node *node,
+											substitute_group_exprs_context *context);
 static void finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry,
 									List *groupClauses, bool hasJoinRTEs,
 									bool have_non_var_grouping);
@@ -1082,6 +1097,7 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
 	List	   *gset_common = NIL;
 	List	   *groupClauses = NIL;
 	List	   *groupClauseCommonVars = NIL;
+	List	   *groupClauseCommonExprs = NIL;
 	bool		have_non_var_grouping;
 	List	   *func_grouped_rels = NIL;
 	ListCell   *l;
@@ -1201,13 +1217,26 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
 		{
 			have_non_var_grouping = true;
 		}
-		else if (!qry->groupingSets ||
-				 list_member_int(gset_common, tle->ressortgroupref))
+
+		if (!qry->groupingSets ||
+			list_member_int(gset_common, tle->ressortgroupref))
 		{
-			groupClauseCommonVars = lappend(groupClauseCommonVars, tle->expr);
+			groupClauseCommonExprs = lappend(groupClauseCommonExprs, tle->expr);
+
+			if (IsA(tle->expr, Var))
+				groupClauseCommonVars = lappend(groupClauseCommonVars, tle->expr);
 		}
+
 	}
 
+	/*
+	 * Now build an RTE and nsitem for the result of the grouping step.
+	 */
+	pstate->p_grouping_nsitem =
+		addRangeTableEntryForGroup(pstate, groupClauses);
+
+	qry->rtable = pstate->p_rtable;
+
 	/*
 	 * Check the targetlist and HAVING clause for ungrouped variables.
 	 *
@@ -1241,6 +1270,15 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
 							have_non_var_grouping,
 							&func_grouped_rels);
 
+	qry->targetList = (List *)
+		substitute_group_exprs((Node *) qry->targetList, pstate,
+							   groupClauses, groupClauseCommonExprs,
+							   have_non_var_grouping);
+	qry->havingQual =
+		substitute_group_exprs(qry->havingQual, pstate,
+							   groupClauses, groupClauseCommonExprs,
+							   have_non_var_grouping);
+
 	/*
 	 * Per spec, aggregates can't appear in a recursive term.
 	 */
@@ -1470,6 +1508,176 @@ check_ungrouped_columns_walker(Node *node,
 								  (void *) context);
 }
 
+static Node *
+substitute_group_exprs(Node *node, ParseState *pstate,
+					   List *groupClauses, List *groupClauseCommonExprs,
+					   bool have_non_var_grouping)
+{
+	substitute_group_exprs_context context;
+
+	context.pstate = pstate;
+	context.groupClauses = groupClauses;
+	context.groupClauseCommonExprs = groupClauseCommonExprs;
+	context.have_non_var_grouping = have_non_var_grouping;
+	context.sublevels_up = 0;
+	return substitute_group_exprs_mutator(node, &context);
+}
+
+static Node *
+substitute_group_exprs_mutator(Node *node,
+							   substitute_group_exprs_context *context)
+{
+	ListCell   *gl;
+
+	if (node == NULL)
+		return NULL;
+
+	if (IsA(node, Aggref))
+	{
+		Aggref	   *agg = (Aggref *) node;
+
+		if ((int) agg->agglevelsup == context->sublevels_up)
+		{
+			/*
+			 * If we find an aggregate call of the original level, do not
+			 * recurse into its normal arguments, ORDER BY arguments, or
+			 * filter; grouped vars there do not need to be replaced.  But we
+			 * should modify direct arguments as though they weren't in an
+			 * aggregate.
+			 */
+			agg = copyObject(agg);
+			agg->aggdirectargs = (List *)
+				substitute_group_exprs_mutator((Node *) agg->aggdirectargs,
+											   context);
+			return (Node *) agg;
+		}
+
+		/*
+		 * We can skip recursing into aggregates of higher levels altogether,
+		 * since they could not possibly contain Vars of concern to us (see
+		 * transformAggregateCall).  We do need to look at aggregates of lower
+		 * levels, however.
+		 */
+		if ((int) agg->agglevelsup > context->sublevels_up)
+			return node;
+	}
+
+	if (IsA(node, GroupingFunc))
+	{
+		GroupingFunc *grp = (GroupingFunc *) node;
+
+		if ((int) grp->agglevelsup >= context->sublevels_up)
+			return node;
+	}
+
+	/*
+	 * If we have any GROUP BY items that are not simple Vars, check to see if
+	 * subexpression as a whole matches any GROUP BY item. We need to do this
+	 * at every recursion level so that we recognize GROUPed-BY expressions
+	 * before reaching variables within them. But this only works at the outer
+	 * query level, as noted above.
+	 */
+	if (context->have_non_var_grouping && context->sublevels_up == 0)
+	{
+		int attnum = 0;
+		foreach(gl, context->groupClauses)
+		{
+			TargetEntry *tle = lfirst(gl);
+
+			attnum++;
+			if (equal(node, tle->expr))
+			{
+				Var    *newvar;
+				int		group_rtindex;
+				ParseNamespaceColumn *group_nscolumns;
+
+				group_rtindex = context->pstate->p_grouping_nsitem->p_rtindex;
+				group_nscolumns = context->pstate->p_grouping_nsitem->p_nscolumns;
+
+				newvar = buildVarFromNSColumn(context->pstate,
+											  group_nscolumns + attnum - 1);
+
+				if (!list_member(context->groupClauseCommonExprs, node))
+					newvar->varnullingrels =
+						bms_add_member(newvar->varnullingrels, group_rtindex);
+
+				return (Node *) newvar;
+			}
+		}
+	}
+
+	if (IsA(node, Const) ||
+		IsA(node, Param))
+		return node;
+
+	/*
+	 * We are only interested in Vars of the original query level.
+	 */
+	if (IsA(node, Var))
+	{
+		Var		   *var = (Var *) node;
+
+		if (var->varlevelsup != context->sublevels_up)
+			return node;		/* it's not local to my query, ignore */
+
+		/*
+		 * Check for a match, if we didn't do it above.
+		 */
+		if (!context->have_non_var_grouping || context->sublevels_up != 0)
+		{
+			int attnum = 0;
+			foreach(gl, context->groupClauses)
+			{
+				Var		   *gvar = (Var *) ((TargetEntry *) lfirst(gl))->expr;
+
+				attnum++;
+				if (IsA(gvar, Var) &&
+					gvar->varno == var->varno &&
+					gvar->varattno == var->varattno &&
+					gvar->varlevelsup == 0)
+				{
+					Var    *newvar;
+					int		group_rtindex;
+					ParseNamespaceColumn *group_nscolumns;
+
+					group_rtindex =
+						context->pstate->p_grouping_nsitem->p_rtindex;
+					group_nscolumns =
+						context->pstate->p_grouping_nsitem->p_nscolumns;
+
+					newvar = buildVarFromNSColumn(context->pstate,
+												  group_nscolumns + attnum - 1);
+					newvar->varlevelsup = context->sublevels_up;
+
+					if (!list_member(context->groupClauseCommonExprs, node))
+						newvar->varnullingrels =
+							bms_add_member(newvar->varnullingrels, group_rtindex);
+
+					return (Node *) newvar;
+				}
+			}
+		}
+
+		return node;
+	}
+
+	if (IsA(node, Query))
+	{
+		/* Recurse into subselects */
+		Query	   *newnode;
+
+		context->sublevels_up++;
+		newnode = query_tree_mutator((Query *) node,
+									 substitute_group_exprs_mutator,
+									 (void *) context,
+									 0);
+		context->sublevels_up--;
+		return (Node *) newnode;
+	}
+	return expression_tree_mutator(node, substitute_group_exprs_mutator,
+								   (void *) context);
+}
+
 /*
  * finalize_grouping_exprs -
  *	  Scan the given expression tree for GROUPING() and related calls,
diff --git a/src/backend/parser/parse_clause.c b/src/backend/parser/parse_clause.c
index 8118036495..350ca1d515 100644
--- a/src/backend/parser/parse_clause.c
+++ b/src/backend/parser/parse_clause.c
@@ -74,8 +74,6 @@ static ParseNamespaceItem *getNSItemForSpecialRelationTypes(ParseState *pstate,
 static Node *transformFromClauseItem(ParseState *pstate, Node *n,
 									 ParseNamespaceItem **top_nsitem,
 									 List **namespace);
-static Var *buildVarFromNSColumn(ParseState *pstate,
-								 ParseNamespaceColumn *nscol);
 static Node *buildMergedJoinVar(ParseState *pstate, JoinType jointype,
 								Var *l_colvar, Var *r_colvar);
 static void markRelsAsNulledBy(ParseState *pstate, Node *n, int jindex);
@@ -1636,7 +1634,7 @@ transformFromClauseItem(ParseState *pstate, Node *n,
  * Note also that no column SELECT privilege is requested here; that would
  * happen only if the column is actually referenced in the query.
  */
-static Var *
+Var *
 buildVarFromNSColumn(ParseState *pstate, ParseNamespaceColumn *nscol)
 {
 	Var		   *var;
diff --git a/src/backend/parser/parse_relation.c b/src/backend/parser/parse_relation.c
index 2f64eaf0e3..6947638425 100644
--- a/src/backend/parser/parse_relation.c
+++ b/src/backend/parser/parse_relation.c
@@ -2557,6 +2557,79 @@ addRangeTableEntryForENR(ParseState *pstate,
 									tupdesc);
 }
 
+/*
+ * Add an entry for grouping step to the pstate's range table (p_rtable).
+ * Then, construct and return a ParseNamespaceItem for the new RTE.
+ */
+ParseNamespaceItem *
+addRangeTableEntryForGroup(ParseState *pstate,
+						   List *groupClauses)
+{
+	RangeTblEntry *rte = makeNode(RangeTblEntry);
+	Alias	   *eref;
+	List	   *groupexprs;
+	List	   *coltypes,
+			   *coltypmods,
+			   *colcollations;
+	ListCell   *lc;
+	ParseNamespaceItem *nsitem;
+
+	Assert(pstate != NULL);
+
+	rte->rtekind = RTE_GROUP;
+	rte->alias = NULL;
+
+	eref = makeAlias("*GROUP*", NIL);
+
+	/* fill in any unspecified alias columns, and extract column type info */
+	groupexprs = NIL;
+	coltypes = coltypmods = colcollations = NIL;
+	foreach(lc, groupClauses)
+	{
+		TargetEntry *te = (TargetEntry *) lfirst(lc);
+		char	   *colname = te->resname ? pstrdup(te->resname) : "unamed_col";
+
+		eref->colnames = lappend(eref->colnames, makeString(colname));
+
+		groupexprs = lappend(groupexprs, copyObject(te->expr));
+
+		coltypes = lappend_oid(coltypes,
+							   exprType((Node *) te->expr));
+		coltypmods = lappend_int(coltypmods,
+								 exprTypmod((Node *) te->expr));
+		colcollations = lappend_oid(colcollations,
+									exprCollation((Node *) te->expr));
+	}
+
+	rte->eref = eref;
+	rte->groupexprs = groupexprs;
+
+	/*
+	 * Set flags.
+	 *
+	 * The grouping step is never checked for access rights, so no need to
+	 * perform addRTEPermissionInfo().
+	 */
+	rte->lateral = false;
+	rte->inFromCl = false;
+
+	/*
+	 * Add completed RTE to pstate's range table list, so that we know its
+	 * index.  But we don't add it to the join list --- caller must do that if
+	 * appropriate.
+	 */
+	pstate->p_rtable = lappend(pstate->p_rtable, rte);
+
+	/*
+	 * Build a ParseNamespaceItem, but don't add it to the pstate's namespace
+	 * list --- caller must do that if appropriate.
+	 */
+	nsitem = buildNSItemFromLists(rte, list_length(pstate->p_rtable),
+								  coltypes, coltypmods, colcollations);
+
+	return nsitem;
+}
+
 
 /*
  * Has the specified refname been selected FOR UPDATE/FOR SHARE?
@@ -3003,6 +3076,7 @@ expandRTE(RangeTblEntry *rte, int rtindex, int sublevels_up,
 			}
 			break;
 		case RTE_RESULT:
+		case RTE_GROUP:
 			/* These expose no columns, so nothing to do */
 			break;
 		default:
@@ -3317,10 +3391,11 @@ get_rte_attribute_is_dropped(RangeTblEntry *rte, AttrNumber attnum)
 		case RTE_TABLEFUNC:
 		case RTE_VALUES:
 		case RTE_CTE:
+		case RTE_GROUP:
 
 			/*
-			 * Subselect, Table Functions, Values, CTE RTEs never have dropped
-			 * columns
+			 * Subselect, Table Functions, Values, CTE, GROUP RTEs never have
+			 * dropped columns
 			 */
 			result = false;
 			break;
diff --git a/src/backend/parser/parse_target.c b/src/backend/parser/parse_target.c
index ee6fcd0503..1f8edc05c9 100644
--- a/src/backend/parser/parse_target.c
+++ b/src/backend/parser/parse_target.c
@@ -380,6 +380,7 @@ markTargetListOrigin(ParseState *pstate, TargetEntry *tle,
 		case RTE_TABLEFUNC:
 		case RTE_NAMEDTUPLESTORE:
 		case RTE_RESULT:
+		case RTE_GROUP:
 			/* not a simple relation, leave it unmarked */
 			break;
 		case RTE_CTE:
@@ -1579,6 +1580,7 @@ expandRecordVariable(ParseState *pstate, Var *var, int levelsup)
 		case RTE_VALUES:
 		case RTE_NAMEDTUPLESTORE:
 		case RTE_RESULT:
+		case RTE_GROUP:
 
 			/*
 			 * This case should not occur: a column of a table, values list,
diff --git a/src/backend/utils/adt/ruleutils.c b/src/backend/utils/adt/ruleutils.c
index 9618619762..f539693bfe 100644
--- a/src/backend/utils/adt/ruleutils.c
+++ b/src/backend/utils/adt/ruleutils.c
@@ -5433,11 +5433,27 @@ get_query_def(Query *query, StringInfo buf, List *parentnamespace,
 {
 	deparse_context context;
 	deparse_namespace dpns;
+	int			rtable_size;
+	ListCell   *lc;
 
 	/* Guard against excessively long or deeply-nested queries */
 	CHECK_FOR_INTERRUPTS();
 	check_stack_depth();
 
+	rtable_size = list_length(query->rtable);
+	foreach (lc, query->rtable)
+	{
+		RangeTblEntry *rte = lfirst_node(RangeTblEntry, lc);
+
+		if (rte->rtekind == RTE_GROUP)
+			rtable_size--;
+	}
+
+	query->targetList = (List *)
+		flatten_group_exprs(NULL, query, (Node *) query->targetList);
+	query->havingQual =
+		flatten_group_exprs(NULL, query, query->havingQual);
+
 	/*
 	 * Before we begin to examine the query, acquire locks on referenced
 	 * relations, and fix up deleted columns in JOIN RTEs.  This ensures
@@ -5454,7 +5470,7 @@ get_query_def(Query *query, StringInfo buf, List *parentnamespace,
 	context.windowClause = NIL;
 	context.windowTList = NIL;
 	context.varprefix = (parentnamespace != NIL ||
-						 list_length(query->rtable) != 1);
+						 rtable_size != 1);
 	context.prettyFlags = prettyFlags;
 	context.wrapColumn = wrapColumn;
 	context.indentLevel = startIndent;
@@ -7838,6 +7854,7 @@ get_name_for_var_field(Var *var, int fieldno,
 		case RTE_VALUES:
 		case RTE_NAMEDTUPLESTORE:
 		case RTE_RESULT:
+		case RTE_GROUP:
 
 			/*
 			 * This case should not occur: a column of a table, values list,
diff --git a/src/include/commands/explain.h b/src/include/commands/explain.h
index 9b8b351d9a..35be084869 100644
--- a/src/include/commands/explain.h
+++ b/src/include/commands/explain.h
@@ -67,6 +67,7 @@ typedef struct ExplainState
 	List	   *deparse_cxt;	/* context list for deparsing expressions */
 	Bitmapset  *printed_subplans;	/* ids of SubPlans we've printed */
 	bool		hide_workers;	/* set if we find an invisible Gather */
+	int			rtable_size;	/* length of rtable excluding GROUP entries */
 	/* state related to the current plan node */
 	ExplainWorkersState *workers_state; /* needed if parallel plan */
 } ExplainState;
diff --git a/src/include/nodes/nodeFuncs.h b/src/include/nodes/nodeFuncs.h
index eaba59bed8..1f0de5b3d8 100644
--- a/src/include/nodes/nodeFuncs.h
+++ b/src/include/nodes/nodeFuncs.h
@@ -31,6 +31,8 @@ struct PlanState;				/* avoid including execnodes.h too */
 #define QTW_DONT_COPY_QUERY			0x40	/* do not copy top Query */
 #define QTW_EXAMINE_SORTGROUP		0x80	/* include SortGroupClause lists */
 
+#define QTW_IGNORE_GROUPEXPRS		0x100	/* GROUP expressions lists */
+
 /* callback function for check_functions_in_node */
 typedef bool (*check_function_callback) (Oid func_id, void *context);
 
diff --git a/src/include/nodes/parsenodes.h b/src/include/nodes/parsenodes.h
index ddfed02db2..a7b6fd3976 100644
--- a/src/include/nodes/parsenodes.h
+++ b/src/include/nodes/parsenodes.h
@@ -1036,6 +1036,7 @@ typedef enum RTEKind
 	RTE_RESULT,					/* RTE represents an empty FROM clause; such
 								 * RTEs are added by the planner, they're not
 								 * present during parsing or rewriting */
+	RTE_GROUP,					/* the grouping step */
 } RTEKind;
 
 typedef struct RangeTblEntry
@@ -1242,6 +1243,12 @@ typedef struct RangeTblEntry
 	/* estimated or actual from caller */
 	Cardinality enrtuples pg_node_attr(query_jumble_ignore);
 
+	/*
+	 * Fields valid for GROUP RTEs (else NULL/zero):
+	 */
+	/* list of expressions grouped on */
+	List	   *groupexprs pg_node_attr(query_jumble_ignore);
+
 	/*
 	 * Fields valid in all RTEs:
 	 */
diff --git a/src/include/nodes/pathnodes.h b/src/include/nodes/pathnodes.h
index 14ef296ab7..c082693e7c 100644
--- a/src/include/nodes/pathnodes.h
+++ b/src/include/nodes/pathnodes.h
@@ -505,6 +505,11 @@ struct PlannerInfo
 	/* true if planning a recursive WITH item */
 	bool		hasRecursion;
 
+	/*
+	 * The rangetable index for the GROUP RTE, or 0 if there is no GROUP RTE.
+	 */
+	int			group_rtindex;
+
 	/*
 	 * Information about aggregates. Filled by preprocess_aggrefs().
 	 */
diff --git a/src/include/optimizer/optimizer.h b/src/include/optimizer/optimizer.h
index 7b63c5cf71..93e3dc719d 100644
--- a/src/include/optimizer/optimizer.h
+++ b/src/include/optimizer/optimizer.h
@@ -201,5 +201,6 @@ extern bool contain_vars_of_level(Node *node, int levelsup);
 extern int	locate_var_of_level(Node *node, int levelsup);
 extern List *pull_var_clause(Node *node, int flags);
 extern Node *flatten_join_alias_vars(PlannerInfo *root, Query *query, Node *node);
+extern Node *flatten_group_exprs(PlannerInfo *root, Query *query, Node *node);
 
 #endif							/* OPTIMIZER_H */
diff --git a/src/include/parser/parse_clause.h b/src/include/parser/parse_clause.h
index e71762b10c..1a1cf3570e 100644
--- a/src/include/parser/parse_clause.h
+++ b/src/include/parser/parse_clause.h
@@ -17,6 +17,8 @@
 #include "parser/parse_node.h"
 
 extern void transformFromClause(ParseState *pstate, List *frmList);
+extern Var *buildVarFromNSColumn(ParseState *pstate,
+								 ParseNamespaceColumn *nscol);
 extern int	setTargetTable(ParseState *pstate, RangeVar *relation,
 						   bool inh, bool alsoSource, AclMode requiredPerms);
 
diff --git a/src/include/parser/parse_node.h b/src/include/parser/parse_node.h
index 5b781d87a9..ef78fd8224 100644
--- a/src/include/parser/parse_node.h
+++ b/src/include/parser/parse_node.h
@@ -237,6 +237,8 @@ struct ParseState
 	ParseParamRefHook p_paramref_hook;
 	CoerceParamHook p_coerce_param_hook;
 	void	   *p_ref_hook_state;	/* common passthrough link for above */
+
+	ParseNamespaceItem *p_grouping_nsitem;	/* NSItem for grouping, or NULL */
 };
 
 /*
diff --git a/src/include/parser/parse_relation.h b/src/include/parser/parse_relation.h
index bea2da5496..91fd8e243b 100644
--- a/src/include/parser/parse_relation.h
+++ b/src/include/parser/parse_relation.h
@@ -100,6 +100,8 @@ extern ParseNamespaceItem *addRangeTableEntryForCTE(ParseState *pstate,
 extern ParseNamespaceItem *addRangeTableEntryForENR(ParseState *pstate,
 													RangeVar *rv,
 													bool inFromCl);
+extern ParseNamespaceItem *addRangeTableEntryForGroup(ParseState *pstate,
+													  List *groupClauses);
 extern RTEPermissionInfo *addRTEPermissionInfo(List **rteperminfos,
 											   RangeTblEntry *rte);
 extern RTEPermissionInfo *getRTEPermissionInfo(List *rteperminfos,
diff --git a/src/test/regress/expected/groupingsets.out b/src/test/regress/expected/groupingsets.out
index e1f0660810..9c7590e7ba 100644
--- a/src/test/regress/expected/groupingsets.out
+++ b/src/test/regress/expected/groupingsets.out
@@ -2150,4 +2150,53 @@ select (select grouping(v1)) from (values ((select 1))) v(v1) group by v1;
         0
 (1 row)
 
+-- test handling of subqueries in grouping sets
+create temp table gstest5(id integer primary key, v integer);
+insert into gstest5 select i, i from generate_series(1,5)i;
+explain (costs off)
+select grouping((select t1.v from gstest5 t2 where id = t1.id)),
+       (select t1.v from gstest5 t2 where id = t1.id) as s
+from gstest5 t1
+group by grouping sets(v, s)
+order by case when grouping((select t1.v from gstest5 t2 where id = t1.id)) = 0
+              then (select t1.v from gstest5 t2 where id = t1.id)
+              else null end
+         nulls first;
+                                                QUERY PLAN                                                 
+-----------------------------------------------------------------------------------------------------------
+ Sort
+   Sort Key: (CASE WHEN (GROUPING((SubPlan 2)) = 0) THEN ((SubPlan 3)) ELSE NULL::integer END) NULLS FIRST
+   ->  HashAggregate
+         Hash Key: t1.v
+         Hash Key: (SubPlan 3)
+         ->  Seq Scan on gstest5 t1
+               SubPlan 3
+                 ->  Bitmap Heap Scan on gstest5 t2
+                       Recheck Cond: (id = t1.id)
+                       ->  Bitmap Index Scan on gstest5_pkey
+                             Index Cond: (id = t1.id)
+(11 rows)
+
+select grouping((select t1.v from gstest5 t2 where id = t1.id)),
+       (select t1.v from gstest5 t2 where id = t1.id) as s
+from gstest5 t1
+group by grouping sets(v, s)
+order by case when grouping((select t1.v from gstest5 t2 where id = t1.id)) = 0
+              then (select t1.v from gstest5 t2 where id = t1.id)
+              else null end
+         nulls first;
+ grouping | s 
+----------+---
+        1 |  
+        1 |  
+        1 |  
+        1 |  
+        1 |  
+        0 | 1
+        0 | 2
+        0 | 3
+        0 | 4
+        0 | 5
+(10 rows)
+
 -- end
diff --git a/src/test/regress/sql/groupingsets.sql b/src/test/regress/sql/groupingsets.sql
index 90ba27257a..0520e44aeb 100644
--- a/src/test/regress/sql/groupingsets.sql
+++ b/src/test/regress/sql/groupingsets.sql
@@ -589,4 +589,27 @@ explain (costs off)
 select (select grouping(v1)) from (values ((select 1))) v(v1) group by v1;
 select (select grouping(v1)) from (values ((select 1))) v(v1) group by v1;
 
+-- test handling of subqueries in grouping sets
+create temp table gstest5(id integer primary key, v integer);
+insert into gstest5 select i, i from generate_series(1,5)i;
+
+explain (costs off)
+select grouping((select t1.v from gstest5 t2 where id = t1.id)),
+       (select t1.v from gstest5 t2 where id = t1.id) as s
+from gstest5 t1
+group by grouping sets(v, s)
+order by case when grouping((select t1.v from gstest5 t2 where id = t1.id)) = 0
+              then (select t1.v from gstest5 t2 where id = t1.id)
+              else null end
+         nulls first;
+
+select grouping((select t1.v from gstest5 t2 where id = t1.id)),
+       (select t1.v from gstest5 t2 where id = t1.id) as s
+from gstest5 t1
+group by grouping sets(v, s)
+order by case when grouping((select t1.v from gstest5 t2 where id = t1.id)) = 0
+              then (select t1.v from gstest5 t2 where id = t1.id)
+              else null end
+         nulls first;
+
 -- end
-- 
2.34.1

