commit 1534e7823b34fc6279a248b1ab59563ae79bead5
Author: Markus Winand <markus.winand@winand.at>
Date:   Thu Sep 7 18:12:24 2023 +0200

    Validate IMMUTABLE/STABLE declarations at CREATE time if possible
    
    SQL-standard function bodies are parsed at CREATE-time and can thus report
    syntax errors and the like. However, commands that contradict the IMMUTABLE
    or STABLE declarations such as INSERT or FOR UPDATE where still accepted.
    
    Also, error reporting for writeable CTEs in non-volatile functions or read-only
    transactions was quite misleading:
    
        SELECT is not allowed in a non-volatile function
    
    New error messages like the following are introduced:
    
        cannot modify table "x" in a non-volatile function

diff --git a/src/backend/commands/functioncmds.c b/src/backend/commands/functioncmds.c
index 49c7864c7c..0466307215 100644
--- a/src/backend/commands/functioncmds.c
+++ b/src/backend/commands/functioncmds.c
@@ -56,6 +56,7 @@
 #include "executor/functions.h"
 #include "funcapi.h"
 #include "miscadmin.h"
+#include "nodes/nodeFuncs.h"
 #include "optimizer/optimizer.h"
 #include "parser/analyze.h"
 #include "parser/parse_coerce.h"
@@ -842,6 +843,28 @@ compute_function_attributes(ParseState *pstate,
 		*parallel_p = interpret_func_parallel(parallel_item);
 }
 
+static bool
+validate_volatility_walker(Node *node, void *ctx)
+{
+	if (node == NULL)
+		return false;
+	if (IsA(node, Query))
+	{
+		Query *q = (Query*)node;
+		if (q->hasForUpdate || q->commandType == CMD_UTILITY)
+				ereport(ERROR,
+						(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+						 errmsg("%s is not allowed in a non-volatile function", CreateCommandName(node))));
+		if (q->resultRelation > 0)
+				ereport(ERROR,
+						(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+						 errmsg("cannot modify table \"%s\" in a non-volatile function",
+						 ((RangeTblEntry *) list_nth(q->rtable, (q->resultRelation)-1)) ->eref->aliasname)));
+		return query_tree_walker(q, validate_volatility_walker, ctx, 0);
+	} else {
+		return expression_tree_walker(node, validate_volatility_walker, ctx);
+	}
+}
 
 /*
  * For a dynamically linked C language object, the form of the clause is
@@ -858,7 +881,7 @@ interpret_AS_clause(Oid languageOid, const char *languageName,
 					List *parameterTypes, List *inParameterNames,
 					char **prosrc_str_p, char **probin_str_p,
 					Node **sql_body_out,
-					const char *queryString)
+					const char *queryString, const char volatility)
 {
 	if (!sql_body_in && !as)
 		ereport(ERROR,
@@ -943,6 +966,8 @@ interpret_AS_clause(Oid languageOid, const char *languageName,
 							errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
 							errmsg("%s is not yet supported in unquoted SQL function body",
 								   GetCommandTagName(CreateCommandTag(q->utilityStmt))));
+				if (volatility != PROVOLATILE_VOLATILE)
+					validate_volatility_walker((Node*)q, NULL);
 				transformed_stmts = lappend(transformed_stmts, q);
 				free_parsestate(pstate);
 			}
@@ -962,6 +987,8 @@ interpret_AS_clause(Oid languageOid, const char *languageName,
 						errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
 						errmsg("%s is not yet supported in unquoted SQL function body",
 							   GetCommandTagName(CreateCommandTag(q->utilityStmt))));
+			if (volatility != PROVOLATILE_VOLATILE)
+				validate_volatility_walker((Node*)q, NULL);
 			free_parsestate(pstate);
 
 			*sql_body_out = (Node *) q;
@@ -1227,7 +1254,7 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt)
 	interpret_AS_clause(languageOid, language, funcname, as_clause, stmt->sql_body,
 						parameterTypes_list, inParameterNames_list,
 						&prosrc_str, &probin_str, &prosqlbody,
-						pstate->p_sourcetext);
+						pstate->p_sourcetext, volatility);
 
 	/*
 	 * Set default values for COST and ROWS depending on other parameters;
diff --git a/src/backend/executor/execMain.c b/src/backend/executor/execMain.c
index 4c5a7bbf62..1ba43135ec 100644
--- a/src/backend/executor/execMain.c
+++ b/src/backend/executor/execMain.c
@@ -816,7 +816,7 @@ ExecCheckXactReadOnly(PlannedStmt *plannedstmt)
 		if (isTempNamespace(get_rel_namespace(perminfo->relid)))
 			continue;
 
-		PreventCommandIfReadOnly(CreateCommandName((Node *) plannedstmt));
+		PreventPlannedStmtIfReadOnly(plannedstmt);
 	}
 
 	if (plannedstmt->commandType != CMD_SELECT || plannedstmt->hasModifyingCTE)
diff --git a/src/backend/executor/functions.c b/src/backend/executor/functions.c
index f55424eb5a..5a0caa1abc 100644
--- a/src/backend/executor/functions.c
+++ b/src/backend/executor/functions.c
@@ -520,11 +520,7 @@ init_execution_state(List *queryTree_list,
 			}
 
 			if (fcache->readonly_func && !CommandIsReadOnly(stmt))
-				ereport(ERROR,
-						(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
-				/* translator: %s is a SQL statement name */
-						 errmsg("%s is not allowed in a non-volatile function",
-								CreateCommandName((Node *) stmt))));
+				ReportNonVolatileViolation(ERRCODE_FEATURE_NOT_SUPPORTED, stmt);
 
 			/* OK, build the execution_state for this query */
 			newes = (execution_state *) palloc(sizeof(execution_state));
diff --git a/src/backend/executor/spi.c b/src/backend/executor/spi.c
index 33975687b3..190661a2bf 100644
--- a/src/backend/executor/spi.c
+++ b/src/backend/executor/spi.c
@@ -1736,11 +1736,7 @@ SPI_cursor_open_internal(const char *name, SPIPlanPtr plan,
 			PlannedStmt *pstmt = lfirst_node(PlannedStmt, lc);
 
 			if (!CommandIsReadOnly(pstmt))
-				ereport(ERROR,
-						(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
-				/* translator: %s is a SQL statement name */
-						 errmsg("%s is not allowed in a non-volatile function",
-								CreateCommandName((Node *) pstmt))));
+				ReportReadOnlyViolation(ERRCODE_FEATURE_NOT_SUPPORTED, pstmt);
 		}
 	}
 
@@ -2629,11 +2625,7 @@ _SPI_execute_plan(SPIPlanPtr plan, const SPIExecuteOptions *options,
 			}
 
 			if (options->read_only && !CommandIsReadOnly(stmt))
-				ereport(ERROR,
-						(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
-				/* translator: %s is a SQL statement name */
-						 errmsg("%s is not allowed in a non-volatile function",
-								CreateCommandName((Node *) stmt))));
+				ReportReadOnlyViolation(ERRCODE_FEATURE_NOT_SUPPORTED, stmt);
 
 			/*
 			 * If not read-only mode, advance the command counter before each
diff --git a/src/backend/tcop/utility.c b/src/backend/tcop/utility.c
index e3ccf6c7f7..51d8d44b04 100644
--- a/src/backend/tcop/utility.c
+++ b/src/backend/tcop/utility.c
@@ -406,6 +406,7 @@ ClassifyUtilityCommandAsReadOnly(Node *parsetree)
  *
  * This is useful partly to ensure consistency of the error message wording;
  * some callers have checked XactReadOnly for themselves.
+ * See also: PreventPlannedStmtIfReadOnly
  */
 void
 PreventCommandIfReadOnly(const char *cmdname)
@@ -418,6 +419,92 @@ PreventCommandIfReadOnly(const char *cmdname)
 						cmdname)));
 }
 
+/*
+ * GetModifyingTableName
+ *
+ * Utility to get the alias name of a table that requires a RowExclusiveLock.
+ * Used to get a helpful name for error messages when a writeable CTE appears
+ * in read-only context.
+ */
+const char *
+GetModifyingTableName(const PlannedStmt *pstmt)
+{
+	const ListCell *l;
+	foreach(l, pstmt->rtable)
+	{
+		RangeTblEntry *rte = lfirst(l);
+
+		if (rte->rtekind == RTE_RELATION && rte->rellockmode == RowExclusiveLock) {
+			return rte->eref->aliasname;
+		}
+	}
+	Assert(false);
+	return ""; /* silence compiler */
+}
+
+/*
+ * ReportReadOnlyViolation: throw error about modifying command in read-only context
+ *
+ * Takes care of the writeable CTE special case that the command tag is SELECT
+ * and thus misleading for a read-only error message.
+ */
+void
+ReportReadOnlyViolation(int sqlerrcode, const PlannedStmt *pstmt)
+{
+	if (CreateCommandTag((Node *) pstmt) == CMDTAG_SELECT)
+		ereport(ERROR,
+				(errcode(sqlerrcode),
+		/* translator: %s is a table name */
+				 errmsg("cannot modify table \"%s\" in a read-only transaction",
+						  GetModifyingTableName(pstmt))));
+	else
+		ereport(ERROR,
+				(errcode(sqlerrcode),
+		/* translator: %s is name of a SQL command, eg INSERT */
+				 errmsg("cannot execute %s in a read-only transaction",
+						  CreateCommandName((Node *) pstmt))));
+}
+
+
+/*
+ * ReportNonVolatileViolation: throw error about modifying command in non-volatile context
+ *
+ * Takes care of the writeable CTE special case that the command tag is SELECT
+ * and thus misleading for a read-only error message.
+ */
+void
+ReportNonVolatileViolation(int sqlerrcode, const PlannedStmt *pstmt)
+{
+	if (CreateCommandTag((Node *) pstmt) == CMDTAG_SELECT)
+		ereport(ERROR,
+				(errcode(sqlerrcode),
+		/* translator: %s is a table name */
+				 errmsg("cannot modify table \"%s\" in a non-volatile function",
+						  GetModifyingTableName(pstmt))));
+	else
+		ereport(ERROR,
+				(errcode(sqlerrcode),
+		/* translator: %s is name of a SQL command, eg INSERT */
+				 errmsg("%s is not allowed in a non-volatile function",
+						  CreateCommandName((Node *) pstmt))));
+}
+
+
+/*
+ * PreventPlannedStmtIfReadOnly: throw error if XactReadOnly
+ *
+ * Like PreventCommandIfReadOnly but takes care of the writeable CTE special
+ * case that the command tag is SELECT and thus inappropriate for a read-only
+ * error message.
+ */
+void
+PreventPlannedStmtIfReadOnly(const PlannedStmt *pstmt)
+{
+	if (XactReadOnly)
+		ReportReadOnlyViolation(ERRCODE_READ_ONLY_SQL_TRANSACTION, pstmt);
+}
+
+
 /*
  * PreventCommandIfParallelMode: throw error if current (sub)transaction is
  * in parallel mode.
diff --git a/src/include/tcop/utility.h b/src/include/tcop/utility.h
index 59e64aea07..1bf8a83559 100644
--- a/src/include/tcop/utility.h
+++ b/src/include/tcop/utility.h
@@ -109,4 +109,9 @@ extern LogStmtLevel GetCommandLogLevel(Node *parsetree);
 
 extern bool CommandIsReadOnly(PlannedStmt *pstmt);
 
+extern const char * GetModifyingTableName(const PlannedStmt *pstm);
+extern void ReportReadOnlyViolation(int sqlerrcode, const PlannedStmt *pstmt);
+extern void ReportNonVolatileViolation(int sqlerrcode, const PlannedStmt *pstmt);
+extern void PreventPlannedStmtIfReadOnly(const PlannedStmt *pstmt);
+
 #endif							/* UTILITY_H */
diff --git a/src/test/regress/expected/create_function_sql.out b/src/test/regress/expected/create_function_sql.out
index 50aca5940f..9dcb77c044 100644
--- a/src/test/regress/expected/create_function_sql.out
+++ b/src/test/regress/expected/create_function_sql.out
@@ -305,6 +305,16 @@ ERROR:  operator does not exist: date > integer
 LINE 3:     RETURN x > 1;
                      ^
 HINT:  No operator matches the given name and argument types. You might need to add explicit type casts.
+-- check tricky violation of STABLE declaration
+CREATE FUNCTION functest_S_xx(x int) RETURNS boolean STABLE
+    LANGUAGE SQL
+    BEGIN ATOMIC
+        WITH cte AS (
+            INSERT INTO functest1 SELECT x RETURNING *
+        )
+        SELECT i<0 FROM cte;
+    END;
+ERROR:  cannot modify table "functest1" in a non-volatile function
 -- tricky parsing
 CREATE FUNCTION functest_S_15(x int) RETURNS boolean
 LANGUAGE SQL
diff --git a/src/test/regress/sql/create_function_sql.sql b/src/test/regress/sql/create_function_sql.sql
index 89e9af3a49..fbbc18e53e 100644
--- a/src/test/regress/sql/create_function_sql.sql
+++ b/src/test/regress/sql/create_function_sql.sql
@@ -204,6 +204,16 @@ CREATE FUNCTION functest_S_xx(x date) RETURNS boolean
     LANGUAGE SQL
     RETURN x > 1;
 
+-- check tricky violation of STABLE declaration
+CREATE FUNCTION functest_S_xx(x int) RETURNS boolean STABLE
+    LANGUAGE SQL
+    BEGIN ATOMIC
+        WITH cte AS (
+            INSERT INTO functest1 SELECT x RETURNING *
+        )
+        SELECT i<0 FROM cte;
+    END;
+
 -- tricky parsing
 CREATE FUNCTION functest_S_15(x int) RETURNS boolean
 LANGUAGE SQL
