From 8a4eee9621bd75e950499c9590560e55e55f96aa Mon Sep 17 00:00:00 2001
From: Dave Cramer <davecramer@gmail.com>
Date: Mon, 13 Mar 2023 16:21:40 -0400
Subject: [PATCH] Add a GUC format_binary, takes a comma separated list of
 OID's that the backend should always send as binary This allows the client to
 request binary format without having to do an extra round trip to request
 binary format if they are using extended protocol. Also works in simple
 query. JDBC for one requires no changes to the driver to accept data in
 binary format for well known types.

---
 src/backend/commands/variable.c     | 76 +++++++++++++++++++++++++++++
 src/backend/tcop/pquery.c           | 59 +++++++++++++++++++---
 src/backend/utils/init/globals.c    |  1 +
 src/backend/utils/misc/guc_tables.c | 13 ++++-
 src/include/miscadmin.h             |  1 +
 src/include/tcop/pquery.h           |  1 +
 src/include/utils/guc_hooks.h       |  2 +
 7 files changed, 145 insertions(+), 8 deletions(-)

diff --git a/src/backend/commands/variable.c b/src/backend/commands/variable.c
index f0f2e07655..efee3ac3e1 100644
--- a/src/backend/commands/variable.c
+++ b/src/backend/commands/variable.c
@@ -30,11 +30,13 @@
 #include "postmaster/postmaster.h"
 #include "postmaster/syslogger.h"
 #include "storage/bufmgr.h"
+#include "tcop/pquery.h"
 #include "utils/acl.h"
 #include "utils/backend_status.h"
 #include "utils/builtins.h"
 #include "utils/datetime.h"
 #include "utils/guc_hooks.h"
+#include "utils/memutils.h"
 #include "utils/snapmgr.h"
 #include "utils/syscache.h"
 #include "utils/timestamp.h"
@@ -1172,6 +1174,80 @@ check_default_with_oids(bool *newval, void **extra, GucSource source)
 	return true;
 }
 
+bool
+check_format_binary( char **newval, void **extra, GucSource source)
+{
+	// sanity check
+	if (*newval == NULL)
+		return false;
+
+	if (strcmp(*newval,"") == 0)
+		return true;
+
+	char *tmp = palloc(strlen(*newval));
+	strcpy(tmp, *newval);
+	char *token = strtok(tmp, ",");
+
+	while(token != NULL)
+	{
+		Oid candidate = atooid(token);
+		if (candidate > OID_MAX)
+			GUC_check_errdetail("OID out of range found in %s, %s", *newval, token);
+
+		// atooid will return 0 aka InvalidOid if it can't convert the string or 0
+		// if it's really 0
+		if (candidate == InvalidOid)
+		{
+			if (errno == EINVAL)
+				GUC_check_errdetail("%s has invalid characters at %s",
+					*newval, token);
+			else
+				GUC_check_errdetail("InvalidOid (0) found in %s", *newval);
+			return false;
+		}
+		else
+			token = strtok(NULL, ",");
+	}
+	return true;
+}
+
+void
+assign_format_binary(const char *newval, void *extra)
+{
+	// check for errors or nothing to do
+	if (newval == NULL || strcmp(newval, "") == 0)
+		return;
+
+	char *tmp = palloc(strlen(newval));
+	strcpy(tmp, newval);
+
+	/* Must save OID list in permanent storage. */
+	MemoryContext oldcxt = MemoryContextSwitchTo(TopMemoryContext);
+
+	// unlikely to have more than 16
+	int length = 16;
+	// +1 for the InvalidOid marker at the end
+	Oid *tmpOids = palloc(sizeof(Oid)*(length+1));
+	int i = 0;
+
+	char *token = strtok(tmp, ",");
+
+	while(token != NULL)
+	{
+		tmpOids[i++] = atooid(token);
+		if (i > length)
+		{
+			length += 16;
+			tmpOids = repalloc(tmpOids, sizeof(Oid)*(length+1));
+		}
+		token = strtok(NULL, ",");
+	}
+	tmpOids[i] = InvalidOid;
+	binary_format_oids = tmpOids;
+	MemoryContextSwitchTo(oldcxt);
+
+}
+
 bool
 check_effective_io_concurrency(int *newval, void **extra, GucSource source)
 {
diff --git a/src/backend/tcop/pquery.c b/src/backend/tcop/pquery.c
index 5f0248acc5..dd0afc6674 100644
--- a/src/backend/tcop/pquery.c
+++ b/src/backend/tcop/pquery.c
@@ -34,6 +34,13 @@
  */
 Portal		ActivePortal = NULL;
 
+/*
+* binary_format_oids is an array of oids that the client has requested to be
+* output in binary format
+*/
+Oid  *binary_format_oids;
+
+
 
 static void ProcessQuery(PlannedStmt *plan,
 						 const char *sourceText,
@@ -58,7 +65,7 @@ static uint64 DoPortalRunFetch(Portal portal,
 							   long count,
 							   DestReceiver *dest);
 static void DoPortalRewind(Portal portal);
-
+static int findOid( Oid *binary_oids, Oid oid);
 
 /*
  * CreateQueryDesc
@@ -644,20 +651,58 @@ PortalSetResultFormat(Portal portal, int nFormats, int16 *formats)
 	}
 	else if (nFormats > 0)
 	{
-		/* single format specified, use for all columns */
-		int16		format1 = formats[0];
+		// The client has requested binary formats for some types
+		if ( binary_format_oids != NULL )
+		{
+			Oid targetOid;
+			for (i = 0; i < natts; i++){
+				targetOid = portal->tupDesc->attrs[i].atttypid;
+				portal->formats[i] = findOid(binary_format_oids, targetOid);
+			}
+		}
+		else
+		{
+			/* single format specified, use for all columns */
+			int16		format1 = formats[0];
 
-		for (i = 0; i < natts; i++)
-			portal->formats[i] = format1;
+			for (i = 0; i < natts; i++)
+				portal->formats[i] = format1;
+		}
 	}
 	else
 	{
 		/* use default format for all columns */
-		for (i = 0; i < natts; i++)
-			portal->formats[i] = 0;
+		if ( binary_format_oids != NULL )
+		{
+			Oid targetOid;
+
+			for (i = 0; i < natts; i++){
+				targetOid = portal->tupDesc->attrs[i].atttypid;
+				portal->formats[i] = findOid(binary_format_oids, targetOid);
+			}
+		}
+		else {
+			/* use default format for all columns */
+			for (i = 0; i < natts; i++)
+				portal->formats[i] = 0;
+		}
 	}
 }
 
+/*
+* Linear search through the array of oids.
+* I don't expect this to ever be a large array
+*/
+static int findOid( Oid *binary_oids, Oid oid)
+{
+	Oid *tmp = binary_oids;
+	while (tmp && *tmp != InvalidOid)
+	{
+		if (*tmp++ == oid) return 1;
+	}
+	return 0;
+ }
+
 /*
  * PortalRun
  *		Run a portal's query or queries.
diff --git a/src/backend/utils/init/globals.c b/src/backend/utils/init/globals.c
index 1b1d814254..545d29810b 100644
--- a/src/backend/utils/init/globals.c
+++ b/src/backend/utils/init/globals.c
@@ -123,6 +123,7 @@ int			IntervalStyle = INTSTYLE_POSTGRES;
 bool		enableFsync = true;
 bool		allowSystemTableMods = false;
 int			work_mem = 4096;
+char        *format_binary = NULL;
 double		hash_mem_multiplier = 2.0;
 int			maintenance_work_mem = 65536;
 int			max_parallel_maintenance_workers = 2;
diff --git a/src/backend/utils/misc/guc_tables.c b/src/backend/utils/misc/guc_tables.c
index 1c0583fe26..fbbc2accd8 100644
--- a/src/backend/utils/misc/guc_tables.c
+++ b/src/backend/utils/misc/guc_tables.c
@@ -69,6 +69,7 @@
 #include "storage/predicate.h"
 #include "storage/standby.h"
 #include "tcop/tcopprot.h"
+#include "tcop/pquery.h"
 #include "tsearch/ts_cache.h"
 #include "utils/builtins.h"
 #include "utils/bytea.h"
@@ -4104,7 +4105,17 @@ struct config_string ConfigureNamesString[] =
 		"",
 		NULL, NULL, NULL
 	},
-
+	{
+		{"format_binary", PGC_USERSET, CLIENT_CONN_STATEMENT,
+			gettext_noop("Sets the type Oid's to be returned in binary format"),
+			gettext_noop("Set by the client to indicate which types are to be "
+						 "returned in binary format. "),
+			GUC_NOT_IN_SAMPLE | GUC_DISALLOW_IN_FILE
+		},
+		&format_binary,
+		"",
+		check_format_binary, assign_format_binary, NULL
+	},
 	{
 		{"search_path", PGC_USERSET, CLIENT_CONN_STATEMENT,
 			gettext_noop("Sets the schema search order for names that are not schema-qualified."),
diff --git a/src/include/miscadmin.h b/src/include/miscadmin.h
index 06a86f9ac1..62db4628c5 100644
--- a/src/include/miscadmin.h
+++ b/src/include/miscadmin.h
@@ -294,6 +294,7 @@ extern void PreventCommandDuringRecovery(const char *cmdname);
 /* in utils/misc/guc_tables.c */
 extern PGDLLIMPORT int trace_recovery_messages;
 extern int	trace_recovery(int trace_level);
+extern PGDLLIMPORT char *format_binary;
 
 /*****************************************************************************
  *	  pdir.h --																 *
diff --git a/src/include/tcop/pquery.h b/src/include/tcop/pquery.h
index a5e65b98aa..5853c796b6 100644
--- a/src/include/tcop/pquery.h
+++ b/src/include/tcop/pquery.h
@@ -22,6 +22,7 @@ struct PlannedStmt;				/* avoid including plannodes.h here */
 
 extern PGDLLIMPORT Portal ActivePortal;
 
+extern PGDLLIMPORT Oid  *binary_format_oids;
 
 extern PortalStrategy ChoosePortalStrategy(List *stmts);
 
diff --git a/src/include/utils/guc_hooks.h b/src/include/utils/guc_hooks.h
index aeb3663071..b7c62bf3bc 100644
--- a/src/include/utils/guc_hooks.h
+++ b/src/include/utils/guc_hooks.h
@@ -57,6 +57,8 @@ extern bool check_default_with_oids(bool *newval, void **extra,
 									GucSource source);
 extern bool check_effective_io_concurrency(int *newval, void **extra,
 										   GucSource source);
+bool check_format_binary( char **newval, void **extra, GucSource source);
+void assign_format_binary(const char *newval, void *extra);
 extern bool check_huge_page_size(int *newval, void **extra, GucSource source);
 extern const char *show_in_hot_standby(void);
 extern bool check_locale_messages(char **newval, void **extra, GucSource source);
-- 
2.37.1 (Apple Git-137.1)

