From 925f06e1c884d24b4ec2ad517c222bd782c40bdf Mon Sep 17 00:00:00 2001
From: Melanie Plageman <melanieplageman@gmail.com>
Date: Thu, 13 Feb 2020 17:12:43 -0800
Subject: [PATCH v1] Find aggregated and unaggregated columns in same function

---
 src/backend/executor/nodeAgg.c | 82 +++++++++++++++++++++++-----------
 1 file changed, 57 insertions(+), 25 deletions(-)

diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index b7f49ceddf..2ba321f279 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -270,8 +270,11 @@ static void finalize_aggregates(AggState *aggstate,
 								AggStatePerAgg peragg,
 								AggStatePerGroup pergroup);
 static TupleTableSlot *project_aggregates(AggState *aggstate);
-static Bitmapset *find_unaggregated_cols(AggState *aggstate);
-static bool find_unaggregated_cols_walker(Node *node, Bitmapset **colnos);
+
+static bool find_aggregated_cols_walker(Node *node, void *context);
+static bool find_unaggregated_cols_walker(Node *node, void *context);
+static void find_cols(AggState *aggstate, Bitmapset **aggregated_colnos, Bitmapset **unaggregated_colnos);
+
 static void build_hash_table(AggState *aggstate);
 static TupleHashEntryData *lookup_hash_entry(AggState *aggstate);
 static void lookup_hash_entries(AggState *aggstate);
@@ -1189,30 +1192,38 @@ project_aggregates(AggState *aggstate)
 	return NULL;
 }
 
-/*
- * find_unaggregated_cols
- *	  Construct a bitmapset of the column numbers of un-aggregated Vars
- *	  appearing in our targetlist and qual (HAVING clause)
- */
-static Bitmapset *
-find_unaggregated_cols(AggState *aggstate)
+
+typedef struct FindColsContext
 {
-	Agg		   *node = (Agg *) aggstate->ss.ps.plan;
-	Bitmapset  *colnos;
-
-	colnos = NULL;
-	(void) find_unaggregated_cols_walker((Node *) node->plan.targetlist,
-										 &colnos);
-	(void) find_unaggregated_cols_walker((Node *) node->plan.qual,
-										 &colnos);
-	return colnos;
+	Bitmapset *aggregated_colnos;
+	Bitmapset *unaggregated_colnos;
+} FindColsContext;
+
+static bool
+find_aggregated_cols_walker(Node *node, void *context)
+{
+	if (node == NULL)
+		return false;
+
+	FindColsContext *find_cols_context = (FindColsContext *) context;
+
+	if (IsA(node, Var))
+	{
+		Var *var = (Var *) node;
+		find_cols_context->aggregated_colnos = bms_add_member(find_cols_context->aggregated_colnos, var->varattno);
+		return false;
+	}
+	return expression_tree_walker(node, find_aggregated_cols_walker, (void *) find_cols_context);
 }
 
 static bool
-find_unaggregated_cols_walker(Node *node, Bitmapset **colnos)
+find_unaggregated_cols_walker(Node *node, void *context)
 {
 	if (node == NULL)
 		return false;
+
+	FindColsContext *find_cols_context = (FindColsContext *) context;
+
 	if (IsA(node, Var))
 	{
 		Var		   *var = (Var *) node;
@@ -1220,18 +1231,36 @@ find_unaggregated_cols_walker(Node *node, Bitmapset **colnos)
 		/* setrefs.c should have set the varno to OUTER_VAR */
 		Assert(var->varno == OUTER_VAR);
 		Assert(var->varlevelsup == 0);
-		*colnos = bms_add_member(*colnos, var->varattno);
+		/*
+		 *	  Construct a bitmapset of the column numbers of un-aggregated Vars
+		 *	  appearing in our targetlist and qual (HAVING clause)
+		 */
+		find_cols_context->unaggregated_colnos = bms_add_member(find_cols_context->unaggregated_colnos, var->varattno);
 		return false;
 	}
 	if (IsA(node, Aggref) ||IsA(node, GroupingFunc))
 	{
-		/* do not descend into aggregate exprs */
-		return false;
+		return find_aggregated_cols_walker(node, (void *) find_cols_context);
 	}
-	return expression_tree_walker(node, find_unaggregated_cols_walker,
-								  (void *) colnos);
+	return expression_tree_walker(node, find_unaggregated_cols_walker, (void *) find_cols_context);
 }
 
+static void
+find_cols(AggState *aggstate, Bitmapset **aggregated_colnos, Bitmapset **unaggregated_colnos)
+{
+	Agg		   *node = (Agg *) aggstate->ss.ps.plan;
+
+	FindColsContext findColsContext;
+	findColsContext.aggregated_colnos = NULL;
+	findColsContext.unaggregated_colnos = NULL;
+	(void) find_unaggregated_cols_walker((Node *) node->plan.targetlist, &findColsContext);
+	(void) find_unaggregated_cols_walker((Node *) node->plan.qual, &findColsContext);
+	*aggregated_colnos = findColsContext.aggregated_colnos;
+	*unaggregated_colnos = findColsContext.unaggregated_colnos;
+}
+
+
+
 /*
  * (Re-)initialize the hash table(s) to empty.
  *
@@ -1318,8 +1347,11 @@ find_hash_columns(AggState *aggstate)
 	EState	   *estate = aggstate->ss.ps.state;
 	int			j;
 
+	Bitmapset *aggregated_colnos;
+	Bitmapset *unaggregated_colnos;
+	find_cols(aggstate, &aggregated_colnos, &unaggregated_colnos);
 	/* Find Vars that will be needed in tlist and qual */
-	base_colnos = find_unaggregated_cols(aggstate);
+	base_colnos = unaggregated_colnos;
 
 	for (j = 0; j < numHashes; ++j)
 	{
-- 
2.20.1 (Apple Git-117)

