From dd690f7bf2a305768bce59fc5690316b1283bbee Mon Sep 17 00:00:00 2001
From: songjinzhou <2903807914@qq.com>
Date: Wed, 18 Jan 2023 17:49:33 +0800
Subject: [PATCH] Support plpgsql multi-range in conditional control

---
 doc/src/sgml/plpgsql.sgml                     |  15 +-
 .../plpgsql/src/expected/plpgsql_control.out  | 112 +++++++++++
 src/pl/plpgsql/src/pl_exec.c                  | 190 +++++++++---------
 src/pl/plpgsql/src/pl_funcs.c                 |  38 ++--
 src/pl/plpgsql/src/pl_gram.y                  | 123 +++++++++---
 src/pl/plpgsql/src/plpgsql.h                  |  13 +-
 src/pl/plpgsql/src/sql/plpgsql_control.sql    |  57 ++++++
 7 files changed, 406 insertions(+), 142 deletions(-)

diff --git a/doc/src/sgml/plpgsql.sgml b/doc/src/sgml/plpgsql.sgml
index 8897a5450a..edf523593a 100644
--- a/doc/src/sgml/plpgsql.sgml
+++ b/doc/src/sgml/plpgsql.sgml
@@ -2518,7 +2518,7 @@ END LOOP;
 
 <synopsis>
 <optional> &lt;&lt;<replaceable>label</replaceable>&gt;&gt; </optional>
-FOR <replaceable>name</replaceable> IN <optional> REVERSE </optional> <replaceable>expression</replaceable> .. <replaceable>expression</replaceable> <optional> BY <replaceable>expression</replaceable> </optional> LOOP
+FOR <replaceable>name</replaceable> IN <replaceable>condition_iterator</replaceable> <optional>, <replaceable>condition_iterator</replaceable> <optional> ... </optional></optional> LOOP
     <replaceable>statements</replaceable>
 END LOOP <optional> <replaceable>label</replaceable> </optional>;
 </synopsis>
@@ -2529,6 +2529,15 @@ END LOOP <optional> <replaceable>label</replaceable> </optional>;
         <replaceable>name</replaceable> is automatically defined as type
         <type>integer</type> and exists only inside the loop (any existing
         definition of the variable name is ignored within the loop).
+        </para>
+
+<synopsis>
+condition_iterator:
+<optional> REVERSE </optional> <replaceable>expression</replaceable> .. <replaceable>expression</replaceable> <optional> BY <replaceable>expression</replaceable> </optional>
+</synopsis>
+
+       <para>
+        Multiple iteration controls may be chained together by separating them with commas. The composition of each condition_iterator is as follows:
         The two expressions giving
         the lower and upper bound of the range are evaluated once when entering
         the loop. If the <literal>BY</literal> clause isn't specified the iteration
@@ -2552,6 +2561,10 @@ END LOOP;
 FOR i IN REVERSE 10..1 BY 2 LOOP
     -- i will take on the values 10,8,6,4,2 within the loop
 END LOOP;
+
+FOR I IN 1..10 BY 3, REVERSE I+10..I+1 BY 3 LOOP
+    -- i will take on the values 1,4,7,10,20,17,14,11 within the loop
+END LOOP;
 </programlisting>
        </para>
 
diff --git a/src/pl/plpgsql/src/expected/plpgsql_control.out b/src/pl/plpgsql/src/expected/plpgsql_control.out
index 328bd48586..d74602a04a 100644
--- a/src/pl/plpgsql/src/expected/plpgsql_control.out
+++ b/src/pl/plpgsql/src/expected/plpgsql_control.out
@@ -79,6 +79,118 @@ begin
 end$$;
 ERROR:  BY value of FOR loop must be greater than zero
 CONTEXT:  PL/pgSQL function inline_code_block line 3 at FOR with integer loop variable
+-- Test in condition list
+do $$
+declare
+	i int;
+begin
+	for i in 1..3 , 51..55 loop
+		raise notice '%', i;
+	end loop;
+
+  for i in 1..3 , reverse 55..51 loop
+		raise info '%', i;
+	end loop;
+	
+	for i in reverse 1..3 loop
+		raise notice '%', i;
+	end loop;
+	
+	for i in 1..3 loop
+		raise notice '%', i;
+	end loop;
+	
+	for i in reverse 3..1 loop
+		raise notice '%', i;
+	end loop;
+	
+	for i in 1..10 by 3 loop
+		raise notice '1..10 by 3: i = %', i;
+	end loop;
+end$$;
+NOTICE:  1
+NOTICE:  2
+NOTICE:  3
+NOTICE:  51
+NOTICE:  52
+NOTICE:  53
+NOTICE:  54
+NOTICE:  55
+INFO:  1
+INFO:  2
+INFO:  3
+INFO:  55
+INFO:  54
+INFO:  53
+INFO:  52
+INFO:  51
+NOTICE:  1
+NOTICE:  2
+NOTICE:  3
+NOTICE:  3
+NOTICE:  2
+NOTICE:  1
+NOTICE:  1..10 by 3: i = 1
+NOTICE:  1..10 by 3: i = 4
+NOTICE:  1..10 by 3: i = 7
+NOTICE:  1..10 by 3: i = 10
+do $$
+declare
+   i int := 10;
+begin
+   for i in reverse i+10..i+1 loop
+      raise info '%', i;
+   end loop;
+end $$;
+INFO:  20
+INFO:  19
+INFO:  18
+INFO:  17
+INFO:  16
+INFO:  15
+INFO:  14
+INFO:  13
+INFO:  12
+INFO:  11
+do $$
+declare
+   j int := 10;
+begin
+   for i in 1..3, reverse j+10..j+1 loop
+      raise info '%', i;
+   end loop;
+end $$;
+INFO:  1
+INFO:  2
+INFO:  3
+INFO:  20
+INFO:  19
+INFO:  18
+INFO:  17
+INFO:  16
+INFO:  15
+INFO:  14
+INFO:  13
+INFO:  12
+INFO:  11
+do $$
+declare
+   j int := 10;
+begin
+   for i in reverse j+10..j+1 loop
+      raise info '%', i;
+   end loop;
+end $$;
+INFO:  20
+INFO:  19
+INFO:  18
+INFO:  17
+INFO:  16
+INFO:  15
+INFO:  14
+INFO:  13
+INFO:  12
+INFO:  11
 -- CONTINUE statement
 create table conttesttbl(idx serial, v integer);
 insert into conttesttbl(v) values(10);
diff --git a/src/pl/plpgsql/src/pl_exec.c b/src/pl/plpgsql/src/pl_exec.c
index 37da624388..8460ac001b 100644
--- a/src/pl/plpgsql/src/pl_exec.c
+++ b/src/pl/plpgsql/src/pl_exec.c
@@ -2675,48 +2675,19 @@ exec_stmt_fori(PLpgSQL_execstate *estate, PLpgSQL_stmt_fori *stmt)
 	int32		step_value;
 	bool		found = false;
 	int			rc = PLPGSQL_RC_OK;
+	
+	ListCell	*lc;
 
 	var = (PLpgSQL_var *) (estate->datums[stmt->var->dno]);
 
-	/*
-	 * Get the value of the lower bound
-	 */
-	value = exec_eval_expr(estate, stmt->lower,
-						   &isnull, &valtype, &valtypmod);
-	value = exec_cast_value(estate, value, &isnull,
-							valtype, valtypmod,
-							var->datatype->typoid,
-							var->datatype->atttypmod);
-	if (isnull)
-		ereport(ERROR,
-				(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
-				 errmsg("lower bound of FOR loop cannot be null")));
-	loop_value = DatumGetInt32(value);
-	exec_eval_cleanup(estate);
-
-	/*
-	 * Get the value of the upper bound
-	 */
-	value = exec_eval_expr(estate, stmt->upper,
-						   &isnull, &valtype, &valtypmod);
-	value = exec_cast_value(estate, value, &isnull,
-							valtype, valtypmod,
-							var->datatype->typoid,
-							var->datatype->atttypmod);
-	if (isnull)
-		ereport(ERROR,
-				(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
-				 errmsg("upper bound of FOR loop cannot be null")));
-	end_value = DatumGetInt32(value);
-	exec_eval_cleanup(estate);
-
-	/*
-	 * Get the step value
-	 */
-	if (stmt->step)
+	foreach(lc, stmt->inlist)
 	{
-		value = exec_eval_expr(estate, stmt->step,
-							   &isnull, &valtype, &valtypmod);
+		PLpgSQL_fori_in_item *in_item = (PLpgSQL_fori_in_item *) lfirst(lc);
+		/*
+		* Get the value of the lower bound
+		*/
+		value = exec_eval_expr(estate, in_item->lower,
+							&isnull, &valtype, &valtypmod);
 		value = exec_cast_value(estate, value, &isnull,
 								valtype, valtypmod,
 								var->datatype->typoid,
@@ -2724,75 +2695,110 @@ exec_stmt_fori(PLpgSQL_execstate *estate, PLpgSQL_stmt_fori *stmt)
 		if (isnull)
 			ereport(ERROR,
 					(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
-					 errmsg("BY value of FOR loop cannot be null")));
-		step_value = DatumGetInt32(value);
+					errmsg("lower bound of FOR loop cannot be null")));
+		loop_value = DatumGetInt32(value);
 		exec_eval_cleanup(estate);
-		if (step_value <= 0)
+
+		/*
+		* Get the value of the upper bound
+		*/
+		value = exec_eval_expr(estate, in_item->upper,
+							&isnull, &valtype, &valtypmod);
+		value = exec_cast_value(estate, value, &isnull,
+								valtype, valtypmod,
+								var->datatype->typoid,
+								var->datatype->atttypmod);
+		if (isnull)
 			ereport(ERROR,
-					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-					 errmsg("BY value of FOR loop must be greater than zero")));
-	}
-	else
-		step_value = 1;
+					(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
+					errmsg("upper bound of FOR loop cannot be null")));
+		end_value = DatumGetInt32(value);
+		exec_eval_cleanup(estate);
 
-	/*
-	 * Now do the loop
-	 */
-	for (;;)
-	{
 		/*
-		 * Check against upper bound
-		 */
-		if (stmt->reverse)
+		* Get the step value
+		*/
+		if (in_item->step)
 		{
-			if (loop_value < end_value)
-				break;
+			value = exec_eval_expr(estate, in_item->step,
+								&isnull, &valtype, &valtypmod);
+			value = exec_cast_value(estate, value, &isnull,
+									valtype, valtypmod,
+									var->datatype->typoid,
+									var->datatype->atttypmod);
+			if (isnull)
+				ereport(ERROR,
+						(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
+						errmsg("BY value of FOR loop cannot be null")));
+			step_value = DatumGetInt32(value);
+			exec_eval_cleanup(estate);
+			if (step_value <= 0)
+				ereport(ERROR,
+						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+						errmsg("BY value of FOR loop must be greater than zero")));
 		}
 		else
+			step_value = 1;
+
+		/*
+		* Now do the loop
+		*/
+		for (;;)
 		{
-			if (loop_value > end_value)
-				break;
-		}
+			/*
+			* Check against upper bound
+			*/
+			if (in_item->reverse)
+			{
+				if (loop_value < end_value)
+					break;
+			}
+			else
+			{
+				if (loop_value > end_value)
+					break;
+			}
 
-		found = true;			/* looped at least once */
+			found = true;			/* looped at least once */
 
-		/*
-		 * Assign current value to loop var
-		 */
-		assign_simple_var(estate, var, Int32GetDatum(loop_value), false, false);
+			/*
+			* Assign current value to loop var
+			*/
+			assign_simple_var(estate, var, Int32GetDatum(loop_value), false, false);
 
-		/*
-		 * Execute the statements
-		 */
-		rc = exec_stmts(estate, stmt->body);
+			/*
+			* Execute the statements
+			*/
+			rc = exec_stmts(estate, stmt->body);
 
-		LOOP_RC_PROCESSING(stmt->label, break);
+			LOOP_RC_PROCESSING(stmt->label, break);
 
-		/*
-		 * Increase/decrease loop value, unless it would overflow, in which
-		 * case exit the loop.
-		 */
-		if (stmt->reverse)
-		{
-			if (loop_value < (PG_INT32_MIN + step_value))
-				break;
-			loop_value -= step_value;
-		}
-		else
-		{
-			if (loop_value > (PG_INT32_MAX - step_value))
-				break;
-			loop_value += step_value;
+			/*
+			* Increase/decrease loop value, unless it would overflow, in which
+			* case exit the loop.
+			*/
+			if (in_item->reverse)
+			{
+				if (loop_value < (PG_INT32_MIN + step_value))
+					break;
+				loop_value -= step_value;
+			}
+			else
+			{
+				if (loop_value > (PG_INT32_MAX - step_value))
+					break;
+				loop_value += step_value;
+			}
 		}
-	}
 
-	/*
-	 * Set the FOUND variable to indicate the result of executing the loop
-	 * (namely, whether we looped one or more times). This must be set here so
-	 * that it does not interfere with the value of the FOUND variable inside
-	 * the loop processing itself.
-	 */
-	exec_set_found(estate, found);
+		/*
+		* Set the FOUND variable to indicate the result of executing the loop
+		* (namely, whether we looped one or more times). This must be set here so
+		* that it does not interfere with the value of the FOUND variable inside
+		* the loop processing itself.
+		*/
+		exec_set_found(estate, found);
+	}
 
 	return rc;
 }
diff --git a/src/pl/plpgsql/src/pl_funcs.c b/src/pl/plpgsql/src/pl_funcs.c
index 5a6eadccd5..7f2c219993 100644
--- a/src/pl/plpgsql/src/pl_funcs.c
+++ b/src/pl/plpgsql/src/pl_funcs.c
@@ -546,9 +546,6 @@ free_while(PLpgSQL_stmt_while *stmt)
 static void
 free_fori(PLpgSQL_stmt_fori *stmt)
 {
-	free_expr(stmt->lower);
-	free_expr(stmt->upper);
-	free_expr(stmt->step);
 	free_stmts(stmt->body);
 }
 
@@ -1076,26 +1073,33 @@ dump_while(PLpgSQL_stmt_while *stmt)
 static void
 dump_fori(PLpgSQL_stmt_fori *stmt)
 {
-	dump_ind();
-	printf("FORI %s %s\n", stmt->var->refname, (stmt->reverse) ? "REVERSE" : "NORMAL");
+	ListCell   *lc;
 
-	dump_indent += 2;
 	dump_ind();
-	printf("    lower = ");
-	dump_expr(stmt->lower);
-	printf("\n");
-	dump_ind();
-	printf("    upper = ");
-	dump_expr(stmt->upper);
-	printf("\n");
-	if (stmt->step)
+	printf("FORI %s\n", stmt->var->refname);
+	foreach(lc, stmt->inlist)
 	{
+		PLpgSQL_fori_in_item *in_item = (PLpgSQL_fori_in_item *) lfirst(lc);
+
+		dump_indent += 2;
 		dump_ind();
-		printf("    step = ");
-		dump_expr(stmt->step);
+		printf("    %s", (in_item->reverse) ? "REVERSE" : "NORMAL");
+
+		printf(" lower = ");
+		dump_expr(in_item->lower);
+		printf(" ,");
+		printf(" upper = ");
+		dump_expr(in_item->upper);
+
+		if (in_item->step)
+		{
+			printf(" , ");
+			printf("step = ");
+			dump_expr(in_item->step);
+		}
 		printf("\n");
+		dump_indent -= 2;
 	}
-	dump_indent -= 2;
 
 	dump_stmts(stmt->body);
 
diff --git a/src/pl/plpgsql/src/pl_gram.y b/src/pl/plpgsql/src/pl_gram.y
index edeb72c380..8b34f48b88 100644
--- a/src/pl/plpgsql/src/pl_gram.y
+++ b/src/pl/plpgsql/src/pl_gram.y
@@ -74,6 +74,9 @@ static	PLpgSQL_expr	*read_sql_expression(int until,
 static	PLpgSQL_expr	*read_sql_expression2(int until, int until2,
 											  const char *expected,
 											  int *endtoken);
+static	PLpgSQL_expr	*read_sql_expression3(int until, int until2, int until3,
+											  const char *expected,
+											  int *endtoken);
 static	PLpgSQL_expr	*read_sql_stmt(void);
 static	PLpgSQL_type	*read_datatype(int tok);
 static	PLpgSQL_stmt	*make_execsql_stmt(int firsttoken, int location);
@@ -1434,6 +1437,7 @@ for_control		: for_variable K_IN
 							PLpgSQL_expr *expr1;
 							int			expr1loc;
 							bool		reverse = false;
+							bool		firstflag = true;
 
 							/*
 							 * We have to distinguish between two
@@ -1473,32 +1477,92 @@ for_control		: for_variable K_IN
 
 							if (tok == DOT_DOT)
 							{
-								/* Saw "..", so it must be an integer loop */
-								PLpgSQL_expr *expr2;
-								PLpgSQL_expr *expr_by;
 								PLpgSQL_var	*fvar;
 								PLpgSQL_stmt_fori *new;
 
-								/*
-								 * Relabel first expression as an expression;
-								 * then we can check its syntax.
-								 */
-								expr1->parseMode = RAW_PARSE_PLPGSQL_EXPR;
-								check_sql_expr(expr1->query, expr1->parseMode,
-											   expr1loc);
-
-								/* Read and check the second one */
-								expr2 = read_sql_expression2(K_LOOP, K_BY,
-															 "LOOP",
-															 &tok);
+								new = palloc0(sizeof(PLpgSQL_stmt_fori));
+								new->cmd_type = PLPGSQL_STMT_FORI;
 
-								/* Get the BY clause if any */
-								if (tok == K_BY)
-									expr_by = read_sql_expression(K_LOOP,
-																  "LOOP");
-								else
-									expr_by = NULL;
+								for (;;)
+								{
+									bool		reverseflag = false;
 
+									if(!firstflag)
+									{
+										if (tok_is_keyword(tok, &yylval,
+														K_REVERSE, "reverse"))
+											reverseflag = true;
+										else
+											plpgsql_push_back_token(tok);
+
+										/*
+										* We read the token again until we see ".." or LOOP,
+										* and likewise tell it not to check syntax.
+										*/
+										expr1 = read_sql_construct(DOT_DOT,
+																',',
+																K_LOOP,
+																", or loop",
+																RAW_PARSE_DEFAULT,
+																true,
+																false,
+																true,
+																&expr1loc,
+																&tok);
+									}
+
+									if (tok == DOT_DOT)
+									{
+										/* Saw "..", so it must be an integer loop */
+										PLpgSQL_expr *expr2;
+										PLpgSQL_expr *expr_by;
+										PLpgSQL_fori_in_item *in_item;
+
+										/*
+										* Relabel first expression as an expression;
+										* then we can check its syntax.
+										*/
+										expr1->parseMode = RAW_PARSE_PLPGSQL_EXPR;
+										check_sql_expr(expr1->query, expr1->parseMode,
+													expr1loc);
+
+										/* Read and check the second one */
+										expr2 = read_sql_expression3(K_BY, ',', K_LOOP,
+																	"by , or loop",
+																	&tok);
+
+										/* Get the BY clause if any */
+										if (tok == K_BY)
+											expr_by = read_sql_expression2(',', K_LOOP,
+																		", or K_LOOP", &tok);
+										else
+											expr_by = NULL;
+
+										in_item = palloc0(sizeof(PLpgSQL_fori_in_item));
+										in_item->reverse = firstflag ? reverse : reverseflag;
+										in_item->lower = expr1;
+										in_item->upper = expr2;
+										in_item->step = expr_by;
+										firstflag = false;
+
+										new->inlist = lappend(new->inlist, in_item);
+
+										/* check for in condition list */
+										if (tok == ',')
+										{
+											tok = yylex();
+											continue;
+										}
+										else if(tok == K_LOOP)
+										{
+											break;
+										}
+										else
+										{
+											yyerror("syntax error");
+										}
+									}
+								}
 								/* Should have had a single variable name */
 								if ($1.scalar && $1.row)
 									ereport(ERROR,
@@ -1515,15 +1579,8 @@ for_control		: for_variable K_IN
 																				  InvalidOid,
 																				  NULL),
 														   true);
-
-								new = palloc0(sizeof(PLpgSQL_stmt_fori));
-								new->cmd_type = PLPGSQL_STMT_FORI;
 								new->stmtid	= ++plpgsql_curr_compile->nstatements;
 								new->var = fvar;
-								new->reverse = reverse;
-								new->lower = expr1;
-								new->upper = expr2;
-								new->step = expr_by;
 
 								$$ = (PLpgSQL_stmt *) new;
 							}
@@ -2645,6 +2702,16 @@ read_sql_expression2(int until, int until2, const char *expected,
 							  true, true, true, NULL, endtoken);
 }
 
+/* Convenience routine to read an expression with three possible terminators */
+static PLpgSQL_expr *
+read_sql_expression3(int until, int until2, int until3, const char *expected,
+					 int *endtoken)
+{
+	return read_sql_construct(until, until2, until3, expected,
+							  RAW_PARSE_PLPGSQL_EXPR,
+							  true, true, true, NULL, endtoken);
+}
+
 /* Convenience routine to read a SQL statement that must end with ';' */
 static PLpgSQL_expr *
 read_sql_stmt(void)
diff --git a/src/pl/plpgsql/src/plpgsql.h b/src/pl/plpgsql/src/plpgsql.h
index 355c9f678d..ab35e8fd72 100644
--- a/src/pl/plpgsql/src/plpgsql.h
+++ b/src/pl/plpgsql/src/plpgsql.h
@@ -660,6 +660,14 @@ typedef struct PLpgSQL_stmt_while
 	List	   *body;			/* List of statements */
 } PLpgSQL_stmt_while;
 
+typedef struct PLpgSQL_fori_in_item
+{
+	PLpgSQL_expr *lower;
+	PLpgSQL_expr *upper;
+	PLpgSQL_expr *step;			/* NULL means default (ie, BY 1) */
+	int			reverse;
+}PLpgSQL_fori_in_item;
+
 /*
  * FOR statement with integer loopvar
  */
@@ -670,10 +678,7 @@ typedef struct PLpgSQL_stmt_fori
 	unsigned int stmtid;
 	char	   *label;
 	PLpgSQL_var *var;
-	PLpgSQL_expr *lower;
-	PLpgSQL_expr *upper;
-	PLpgSQL_expr *step;			/* NULL means default (ie, BY 1) */
-	int			reverse;
+	List	   *inlist;			/* List of in conditions */
 	List	   *body;			/* List of statements */
 } PLpgSQL_stmt_fori;
 
diff --git a/src/pl/plpgsql/src/sql/plpgsql_control.sql b/src/pl/plpgsql/src/sql/plpgsql_control.sql
index ed7231134f..b60097b6e3 100644
--- a/src/pl/plpgsql/src/sql/plpgsql_control.sql
+++ b/src/pl/plpgsql/src/sql/plpgsql_control.sql
@@ -58,6 +58,63 @@ begin
   end loop;
 end$$;
 
+-- Test in condition list
+
+do $$
+declare
+	i int;
+begin
+	for i in 1..3 , 51..55 loop
+		raise notice '%', i;
+	end loop;
+
+  for i in 1..3 , reverse 55..51 loop
+		raise info '%', i;
+	end loop;
+	
+	for i in reverse 1..3 loop
+		raise notice '%', i;
+	end loop;
+	
+	for i in 1..3 loop
+		raise notice '%', i;
+	end loop;
+	
+	for i in reverse 3..1 loop
+		raise notice '%', i;
+	end loop;
+	
+	for i in 1..10 by 3 loop
+		raise notice '1..10 by 3: i = %', i;
+	end loop;
+end$$;
+
+do $$
+declare
+   i int := 10;
+begin
+   for i in reverse i+10..i+1 loop
+      raise info '%', i;
+   end loop;
+end $$;
+
+do $$
+declare
+   j int := 10;
+begin
+   for i in 1..3, reverse j+10..j+1 loop
+      raise info '%', i;
+   end loop;
+end $$;
+
+do $$
+declare
+   j int := 10;
+begin
+   for i in reverse j+10..j+1 loop
+      raise info '%', i;
+   end loop;
+end $$;
 
 -- CONTINUE statement
 
-- 
2.36.1.windows.1

