From f838dd377c30778f530ec19bc004cba16a80d8e8 Mon Sep 17 00:00:00 2001
From: Samay Sharma <smilingsamay@gmail.com>
Date: Tue, 15 Feb 2022 22:23:29 -0800
Subject: [PATCH 1/3] Add support for custom authentication methods

Currently, PostgreSQL supports only a set of pre-defined authentication
methods. This patch adds support for 2 hooks which allow users to add
their custom authentication methods by defining a check function and an
error function. Users can then use these methods by using a new "custom"
keyword in pg_hba.conf and specifying the authentication provider they
want to use.
---
 src/backend/libpq/auth.c | 85 ++++++++++++++++++++++++++++++----------
 src/backend/libpq/hba.c  | 36 +++++++++++++++++
 src/include/libpq/auth.h | 27 +++++++++++++
 src/include/libpq/hba.h  |  4 ++
 4 files changed, 131 insertions(+), 21 deletions(-)

diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index efc53f3135..2e3d02b35a 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -47,8 +47,6 @@
  *----------------------------------------------------------------
  */
 static void auth_failed(Port *port, int status, const char *logdetail);
-static char *recv_password_packet(Port *port);
-static void set_authn_id(Port *port, const char *id);
 
 
 /*----------------------------------------------------------------
@@ -206,23 +204,6 @@ static int	pg_SSPI_make_upn(char *accountname,
 static int	CheckRADIUSAuth(Port *port);
 static int	PerformRadiusTransaction(const char *server, const char *secret, const char *portstr, const char *identifier, const char *user_name, const char *passwd);
 
-
-/*
- * Maximum accepted size of GSS and SSPI authentication tokens.
- * We also use this as a limit on ordinary password packet lengths.
- *
- * Kerberos tickets are usually quite small, but the TGTs issued by Windows
- * domain controllers include an authorization field known as the Privilege
- * Attribute Certificate (PAC), which contains the user's Windows permissions
- * (group memberships etc.). The PAC is copied into all tickets obtained on
- * the basis of this TGT (even those issued by Unix realms which the Windows
- * realm trusts), and can be several kB in size. The maximum token size
- * accepted by Windows systems is determined by the MaxAuthToken Windows
- * registry setting. Microsoft recommends that it is not set higher than
- * 65535 bytes, so that seems like a reasonable limit for us as well.
- */
-#define PG_MAX_AUTH_TOKEN_LENGTH	65535
-
 /*----------------------------------------------------------------
  * Global authentication functions
  *----------------------------------------------------------------
@@ -235,6 +216,16 @@ static int	PerformRadiusTransaction(const char *server, const char *secret, cons
  */
 ClientAuthentication_hook_type ClientAuthentication_hook = NULL;
 
+/*
+ * These hooks allow plugins to get control of the client authentication check
+ * and error reporting logic. This allows users to write extensions to
+ * implement authentication using any protocol of their choice. To acquire these
+ * hooks, plugins need to call the RegisterAuthProvider() function.
+ */
+static CustomAuthenticationCheck_hook_type CustomAuthenticationCheck_hook = NULL;
+static CustomAuthenticationError_hook_type CustomAuthenticationError_hook = NULL;
+char *custom_provider_name = NULL;
+
 /*
  * Tell the user the authentication failed, but not (much about) why.
  *
@@ -311,6 +302,12 @@ auth_failed(Port *port, int status, const char *logdetail)
 		case uaRADIUS:
 			errstr = gettext_noop("RADIUS authentication failed for user \"%s\"");
 			break;
+		case uaCustom:
+			if (CustomAuthenticationError_hook)
+				errstr = CustomAuthenticationError_hook(port);
+			else
+				errstr = gettext_noop("Custom authentication failed for user \"%s\"");
+			break;
 		default:
 			errstr = gettext_noop("authentication failed for user \"%s\": invalid authentication method");
 			break;
@@ -345,7 +342,7 @@ auth_failed(Port *port, int status, const char *logdetail)
  * lifetime of the Port, so it is safe to pass a string that is managed by an
  * external library.
  */
-static void
+void
 set_authn_id(Port *port, const char *id)
 {
 	Assert(id);
@@ -630,6 +627,10 @@ ClientAuthentication(Port *port)
 		case uaTrust:
 			status = STATUS_OK;
 			break;
+		case uaCustom:
+			if (CustomAuthenticationCheck_hook)
+				status = CustomAuthenticationCheck_hook(port);
+			break;
 	}
 
 	if ((status == STATUS_OK && port->hba->clientcert == clientCertFull)
@@ -689,7 +690,7 @@ sendAuthRequest(Port *port, AuthRequest areq, const char *extradata, int extrale
  *
  * Returns NULL if couldn't get password, else palloc'd string.
  */
-static char *
+char *
 recv_password_packet(Port *port)
 {
 	StringInfoData buf;
@@ -3343,3 +3344,45 @@ PerformRadiusTransaction(const char *server, const char *secret, const char *por
 		}
 	}							/* while (true) */
 }
+
+/*----------------------------------------------------------------
+ * Custom authentication
+ *----------------------------------------------------------------
+ */
+
+/*
+ * RegisterAuthProvider registers a custom authentication provider to be
+ * used for authentication. Currently, we allow only one authentication
+ * provider to be registered for use at a time.
+ *
+ * This function should be called in _PG_init() by any extension looking to
+ * add a custom authentication method.
+ */
+void RegisterAuthProvider(const char *provider_name,
+		CustomAuthenticationCheck_hook_type AuthenticationCheckFunction,
+		CustomAuthenticationError_hook_type AuthenticationErrorFunction)
+{
+	if (provider_name == NULL)
+	{
+		ereport(ERROR,
+				(errmsg("cannot register authentication provider without name")));
+	}
+
+	if (AuthenticationCheckFunction == NULL)
+	{
+		ereport(ERROR,
+				(errmsg("cannot register authentication provider without a check function")));
+	}
+
+	if (custom_provider_name)
+	{
+		ereport(ERROR,
+				(errmsg("cannot register authentication provider %s", provider_name),
+				 errdetail("Only one authentication provider allowed.  Provider %s is already registered.",
+							custom_provider_name)));
+	}
+
+	custom_provider_name = pstrdup(provider_name);
+	CustomAuthenticationCheck_hook = AuthenticationCheckFunction;
+	CustomAuthenticationError_hook = AuthenticationErrorFunction;
+}
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index d84a40b726..956d7d6857 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -134,6 +134,7 @@ static const char *const UserAuthName[] =
 	"ldap",
 	"cert",
 	"radius",
+	"custom",
 	"peer"
 };
 
@@ -1399,6 +1400,8 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
 #endif
 	else if (strcmp(token->string, "radius") == 0)
 		parsedline->auth_method = uaRADIUS;
+	else if (strcmp(token->string, "custom") == 0)
+		parsedline->auth_method = uaCustom;
 	else
 	{
 		ereport(elevel,
@@ -1691,6 +1694,14 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
 		parsedline->clientcert = clientCertFull;
 	}
 
+	/*
+	 * Ensure that the provider name is specified for custom authentication method.
+	 */
+	if (parsedline->auth_method == uaCustom)
+	{
+		MANDATORY_AUTH_ARG(parsedline->custom_provider, "provider", "custom");
+	}
+
 	return parsedline;
 }
 
@@ -2102,6 +2113,31 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
 		hbaline->radiusidentifiers = parsed_identifiers;
 		hbaline->radiusidentifiers_s = pstrdup(val);
 	}
+	else if (strcmp(name, "provider") == 0)
+	{
+		REQUIRE_AUTH_OPTION(uaCustom, "provider", "custom");
+
+		/*
+		 * Verify that the provider mentioned is same as the one loaded
+		 * via shared_preload_libraries.
+		 */
+
+		if (custom_provider_name == NULL || strcmp(val,custom_provider_name) != 0)
+		{
+			ereport(elevel,
+					(errcode(ERRCODE_CONFIG_FILE_ERROR),
+					 errmsg("cannot use authentication provider %s",val),
+					 errhint("Load authentication provider via shared_preload_libraries."),
+					 errcontext("line %d of configuration file \"%s\"",
+							line_num, HbaFileName)));
+
+			return false;
+		}
+		else
+		{
+			hbaline->custom_provider = pstrdup(val);
+		}
+	}
 	else
 	{
 		ereport(elevel,
diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index 6d7ee1acb9..1d10cccc1b 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -23,9 +23,36 @@ extern char *pg_krb_realm;
 extern void ClientAuthentication(Port *port);
 extern void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata,
 							int extralen);
+extern void set_authn_id(Port *port, const char *id);
+extern char *recv_password_packet(Port *port);
 
 /* Hook for plugins to get control in ClientAuthentication() */
+typedef int (*CustomAuthenticationCheck_hook_type) (Port *);
 typedef void (*ClientAuthentication_hook_type) (Port *, int);
 extern PGDLLIMPORT ClientAuthentication_hook_type ClientAuthentication_hook;
 
+/* Hook for plugins to report error messages in auth_failed() */
+typedef const char * (*CustomAuthenticationError_hook_type) (Port *);
+
+extern void RegisterAuthProvider
+		(const char *provider_name,
+		 CustomAuthenticationCheck_hook_type CustomAuthenticationCheck_hook,
+		 CustomAuthenticationError_hook_type CustomAuthenticationError_hook);
+
+/*
+ * Maximum accepted size of GSS and SSPI authentication tokens.
+ * We also use this as a limit on ordinary password packet lengths.
+ *
+ * Kerberos tickets are usually quite small, but the TGTs issued by Windows
+ * domain controllers include an authorization field known as the Privilege
+ * Attribute Certificate (PAC), which contains the user's Windows permissions
+ * (group memberships etc.). The PAC is copied into all tickets obtained on
+ * the basis of this TGT (even those issued by Unix realms which the Windows
+ * realm trusts), and can be several kB in size. The maximum token size
+ * accepted by Windows systems is determined by the MaxAuthToken Windows
+ * registry setting. Microsoft recommends that it is not set higher than
+ * 65535 bytes, so that seems like a reasonable limit for us as well.
+ */
+#define PG_MAX_AUTH_TOKEN_LENGTH	65535
+
 #endif							/* AUTH_H */
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index 8d9f3821b1..c5aef6994c 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -38,6 +38,7 @@ typedef enum UserAuth
 	uaLDAP,
 	uaCert,
 	uaRADIUS,
+	uaCustom,
 	uaPeer
 #define USER_AUTH_LAST uaPeer	/* Must be last value of this enum */
 } UserAuth;
@@ -120,6 +121,7 @@ typedef struct HbaLine
 	char	   *radiusidentifiers_s;
 	List	   *radiusports;
 	char	   *radiusports_s;
+	char	   *custom_provider;
 } HbaLine;
 
 typedef struct IdentLine
@@ -144,4 +146,6 @@ extern int	check_usermap(const char *usermap_name,
 						  bool case_sensitive);
 extern bool pg_isblank(const char c);
 
+extern char *custom_provider_name;
+
 #endif							/* HBA_H */
-- 
2.34.1

