From 81b1d7cdb64e91a59670609510c08a55ced04095 Mon Sep 17 00:00:00 2001
From: Jacob Champion <jacob.champion@enterprisedb.com>
Date: Thu, 25 Apr 2024 15:26:40 -0700
Subject: [PATCH v3] jsonapi: add lexer option to keep token ownership

Commit 0785d1b8b adds support for libpq as a JSON client, but
allocations for string tokens can still be leaked during parsing
failures. This is tricky to fix for the object_field semantic callbacks:
the field name must remain valid until the end of the object, but if a
parsing error is encountered partway through, object_field_end() won't
be invoked and the client won't get a chance to free the field name.

At Andrew's suggestion, add a flag to switch the ownership of parsed
tokens to the lexer. When this is enabled, the client must make a copy
of any tokens it wants to persist past the callback lifetime, but the
lexer will handle necessary cleanup on failure.
---
 src/common/jsonapi.c         | 90 ++++++++++++++++++++++++++++++++----
 src/include/common/jsonapi.h | 28 +++++++++--
 2 files changed, 107 insertions(+), 11 deletions(-)

diff --git a/src/common/jsonapi.c b/src/common/jsonapi.c
index 45838d8a18..79f5ffa238 100644
--- a/src/common/jsonapi.c
+++ b/src/common/jsonapi.c
@@ -280,6 +280,7 @@ static JsonParseErrorType parse_array_element(JsonLexContext *lex, const JsonSem
 static JsonParseErrorType parse_array(JsonLexContext *lex, const JsonSemAction *sem);
 static JsonParseErrorType report_parse_error(JsonParseContext ctx, JsonLexContext *lex);
 static bool allocate_incremental_state(JsonLexContext *lex);
+static inline void set_fname(JsonLexContext *lex, char *fname);
 
 /* the null action object used for pure validation */
 const JsonSemAction nullSemAction =
@@ -437,7 +438,7 @@ allocate_incremental_state(JsonLexContext *lex)
 			   *fnull;
 
 	lex->inc_state = ALLOC0(sizeof(JsonIncrementalState));
-	pstack = ALLOC(sizeof(JsonParserStack));
+	pstack = ALLOC0(sizeof(JsonParserStack));
 	prediction = ALLOC(JS_STACK_CHUNK_SIZE * JS_MAX_PROD_LEN);
 	fnames = ALLOC(JS_STACK_CHUNK_SIZE * sizeof(char *));
 	fnull = ALLOC(JS_STACK_CHUNK_SIZE * sizeof(bool));
@@ -464,10 +465,17 @@ allocate_incremental_state(JsonLexContext *lex)
 	lex->pstack = pstack;
 	lex->pstack->stack_size = JS_STACK_CHUNK_SIZE;
 	lex->pstack->prediction = prediction;
-	lex->pstack->pred_index = 0;
 	lex->pstack->fnames = fnames;
 	lex->pstack->fnull = fnull;
 
+	/*
+	 * fnames between 0 and lex_level must always be defined so that
+	 * freeJsonLexContext() can handle them safely. inc/dec_lex_level() handle
+	 * the rest.
+	 */
+	Assert(lex->lex_level == 0);
+	lex->pstack->fnames[0] = NULL;
+
 	lex->incremental = true;
 	return true;
 }
@@ -530,6 +538,15 @@ makeJsonLexContextIncremental(JsonLexContext *lex, int encoding,
 	return lex;
 }
 
+void
+setJsonLexContextOwnsTokens(JsonLexContext *lex, bool owned_by_context)
+{
+	if (owned_by_context)
+		lex->flags |= JSONLEX_CTX_OWNS_TOKENS;
+	else
+		lex->flags &= ~JSONLEX_CTX_OWNS_TOKENS;
+}
+
 static inline bool
 inc_lex_level(JsonLexContext *lex)
 {
@@ -569,12 +586,23 @@ inc_lex_level(JsonLexContext *lex)
 	}
 
 	lex->lex_level += 1;
+
+	if (lex->incremental)
+	{
+		/*
+		 * Ensure freeJsonLexContext() remains safe even if no fname is assigned
+		 * at this level.
+		 */
+		lex->pstack->fnames[lex->lex_level] = NULL;
+	}
+
 	return true;
 }
 
 static inline void
 dec_lex_level(JsonLexContext *lex)
 {
+	set_fname(lex, NULL); /* free the current level's fname, if needed */
 	lex->lex_level -= 1;
 }
 
@@ -608,6 +636,15 @@ have_prediction(JsonParserStack *pstack)
 static inline void
 set_fname(JsonLexContext *lex, char *fname)
 {
+	if (lex->flags & JSONLEX_CTX_OWNS_TOKENS)
+	{
+		/*
+		 * Don't leak prior fnames. If one hasn't been assigned yet,
+		 * inc_lex_level ensured that it's NULL (and therefore safe to free).
+		 */
+		FREE(lex->pstack->fnames[lex->lex_level]);
+	}
+
 	lex->pstack->fnames[lex->lex_level] = fname;
 }
 
@@ -655,8 +692,19 @@ freeJsonLexContext(JsonLexContext *lex)
 		jsonapi_termStringInfo(&lex->inc_state->partial_token);
 		FREE(lex->inc_state);
 		FREE(lex->pstack->prediction);
+
+		if (lex->flags & JSONLEX_CTX_OWNS_TOKENS)
+		{
+			int			i;
+
+			/* Clean up any tokens that were left behind. */
+			for (i = 0; i <= lex->lex_level; i++)
+				FREE(lex->pstack->fnames[i]);
+		}
+
 		FREE(lex->pstack->fnames);
 		FREE(lex->pstack->fnull);
+		FREE(lex->pstack->scalar_val);
 		FREE(lex->pstack);
 	}
 
@@ -1086,6 +1134,17 @@ pg_parse_json_incremental(JsonLexContext *lex,
 						if (sfunc != NULL)
 						{
 							result = (*sfunc) (sem->semstate, pstack->scalar_val, pstack->scalar_tok);
+
+							/*
+							 * Either ownership of the token passed to the
+							 * callback, or we need to free it now. Either way,
+							 * clear our pointer to it so it doesn't get freed
+							 * in the future.
+							 */
+							if (lex->flags & JSONLEX_CTX_OWNS_TOKENS)
+								FREE(pstack->scalar_val);
+							pstack->scalar_val = NULL;
+
 							if (result != JSON_SUCCESS)
 								return result;
 						}
@@ -1221,11 +1280,17 @@ parse_scalar(JsonLexContext *lex, const JsonSemAction *sem)
 	/* consume the token */
 	result = json_lex(lex);
 	if (result != JSON_SUCCESS)
+	{
+		FREE(val);
 		return result;
+	}
 
-	/* invoke the callback */
+	/* invoke the callback, which may take ownership of val */
 	result = (*sfunc) (sem->semstate, val, tok);
 
+	if (lex->flags & JSONLEX_CTX_OWNS_TOKENS)
+		FREE(val);
+
 	return result;
 }
 
@@ -1238,7 +1303,7 @@ parse_object_field(JsonLexContext *lex, const JsonSemAction *sem)
 	 * generally call a field name a "key".
 	 */
 
-	char	   *fname = NULL;	/* keep compiler quiet */
+	char	   *fname = NULL;
 	json_ofield_action ostart = sem->object_field_start;
 	json_ofield_action oend = sem->object_field_end;
 	bool		isnull;
@@ -1255,11 +1320,17 @@ parse_object_field(JsonLexContext *lex, const JsonSemAction *sem)
 	}
 	result = json_lex(lex);
 	if (result != JSON_SUCCESS)
+	{
+		FREE(fname);
 		return result;
+	}
 
 	result = lex_expect(JSON_PARSE_OBJECT_LABEL, lex, JSON_TOKEN_COLON);
 	if (result != JSON_SUCCESS)
+	{
+		FREE(fname);
 		return result;
+	}
 
 	tok = lex_peek(lex);
 	isnull = tok == JSON_TOKEN_NULL;
@@ -1268,7 +1339,7 @@ parse_object_field(JsonLexContext *lex, const JsonSemAction *sem)
 	{
 		result = (*ostart) (sem->semstate, fname, isnull);
 		if (result != JSON_SUCCESS)
-			return result;
+			goto ofield_cleanup;
 	}
 
 	switch (tok)
@@ -1283,16 +1354,19 @@ parse_object_field(JsonLexContext *lex, const JsonSemAction *sem)
 			result = parse_scalar(lex, sem);
 	}
 	if (result != JSON_SUCCESS)
-		return result;
+		goto ofield_cleanup;
 
 	if (oend != NULL)
 	{
 		result = (*oend) (sem->semstate, fname, isnull);
 		if (result != JSON_SUCCESS)
-			return result;
+			goto ofield_cleanup;
 	}
 
-	return JSON_SUCCESS;
+ofield_cleanup:
+	if (lex->flags & JSONLEX_CTX_OWNS_TOKENS)
+		FREE(fname);
+	return result;
 }
 
 static JsonParseErrorType
diff --git a/src/include/common/jsonapi.h b/src/include/common/jsonapi.h
index c524ff5be8..167615a557 100644
--- a/src/include/common/jsonapi.h
+++ b/src/include/common/jsonapi.h
@@ -92,9 +92,11 @@ typedef struct JsonIncrementalState JsonIncrementalState;
  * conjunction with token_start.
  *
  * JSONLEX_FREE_STRUCT/STRVAL are used to drive freeJsonLexContext.
+ * JSONLEX_CTX_OWNS_TOKENS is used by setJsonLexContextOwnsTokens.
  */
 #define JSONLEX_FREE_STRUCT			(1 << 0)
 #define JSONLEX_FREE_STRVAL			(1 << 1)
+#define JSONLEX_CTX_OWNS_TOKENS		(1 << 2)
 typedef struct JsonLexContext
 {
 	const char *input;
@@ -130,9 +132,10 @@ typedef JsonParseErrorType (*json_scalar_action) (void *state, char *token, Json
  * to doing a pure parse with no side-effects, and is therefore exactly
  * what the json input routines do.
  *
- * The 'fname' and 'token' strings passed to these actions are palloc'd.
- * They are not free'd or used further by the parser, so the action function
- * is free to do what it wishes with them.
+ * By default, the 'fname' and 'token' strings passed to these actions are
+ * palloc'd.  They are not free'd or used further by the parser, so the action
+ * function is free to do what it wishes with them. This behavior may be
+ * modified by setJsonLexContextOwnsTokens().
  *
  * All action functions return JsonParseErrorType.  If the result isn't
  * JSON_SUCCESS, the parse is abandoned and that error code is returned.
@@ -216,6 +219,25 @@ extern JsonLexContext *makeJsonLexContextIncremental(JsonLexContext *lex,
 													 int encoding,
 													 bool need_escapes);
 
+/*
+ * Sets whether tokens passed to semantic action callbacks are owned by the
+ * context (in which case, the callback must duplicate the tokens for long-term
+ * storage) or by the callback (in which case, the callback must explicitly
+ * free tokens to avoid leaks).
+ *
+ * By default, this setting is false: the callback owns the tokens that are
+ * passed to it (and if parsing fails between the two object-field callbacks,
+ * the field name token will likely leak). If set to true, tokens will be freed
+ * by the lexer after the callback completes.
+ *
+ * Setting this to true is important for long-lived clients (such as libpq)
+ * that must not leak memory during a parse failure. For a server backend using
+ * memory contexts, or a client application which will exit on parse failure,
+ * this setting is less critical.
+ */
+extern void setJsonLexContextOwnsTokens(JsonLexContext *lex,
+										bool owned_by_context);
+
 extern void freeJsonLexContext(JsonLexContext *lex);
 
 /* lex one token */
-- 
2.34.1

