This is an automated email from the ASF dual-hosted git repository. siddteotia pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new 70c4c5b801 [multistage] Table level Access Validation, QPS Quota, Phase Metrics for multistage queries (#10534) 70c4c5b801 is described below commit 70c4c5b8019b3c3b71865bde71e4899c3659082e Author: Vivek Iyer Vaidyanathan <vviveki...@gmail.com> AuthorDate: Mon Apr 10 11:24:36 2023 -0700 [multistage] Table level Access Validation, QPS Quota, Phase Metrics for multistage queries (#10534) * Table level ACL for multistage queries * Fix stylecheck issues * Address review comments * Add QPS Quotas and phase timings for all tables in a multistage query * Address review comments --- .../org/apache/pinot/broker/api/AccessControl.java | 11 +++ .../broker/AllowAllAccessControlFactory.java | 6 ++ .../broker/BasicAuthAccessControlFactory.java | 45 +++++++++-- .../broker/ZkBasicAuthAccessControlFactory.java | 60 +++++++++----- .../MultiStageBrokerRequestHandler.java | 91 ++++++++++++++++++++-- .../broker/broker/BasicAuthAccessControlTest.java | 35 ++++++++- .../org/apache/pinot/query/QueryEnvironment.java | 44 +++++++++-- 7 files changed, 247 insertions(+), 45 deletions(-) diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/api/AccessControl.java b/pinot-broker/src/main/java/org/apache/pinot/broker/api/AccessControl.java index c8e252ee27..485131a3e5 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/api/AccessControl.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/api/AccessControl.java @@ -18,6 +18,7 @@ */ package org.apache.pinot.broker.api; +import java.util.Set; import org.apache.pinot.common.request.BrokerRequest; import org.apache.pinot.spi.annotations.InterfaceAudience; import org.apache.pinot.spi.annotations.InterfaceStability; @@ -47,4 +48,14 @@ public interface AccessControl { * @return {@code true} if authorized, {@code false} otherwise */ boolean hasAccess(RequesterIdentity requesterIdentity, BrokerRequest brokerRequest); + + /** + * Fine-grained access 
control on pinot tables. + * + * @param requesterIdentity requester identity + * @param tables Set of pinot tables used in the query. Table name can be with or without tableType. + * + * @return {@code true} if authorized, {@code false} otherwise + */ + boolean hasAccess(RequesterIdentity requesterIdentity, Set<String> tables); } diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/AllowAllAccessControlFactory.java b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/AllowAllAccessControlFactory.java index e5d96a424f..1e5a888a66 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/AllowAllAccessControlFactory.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/AllowAllAccessControlFactory.java @@ -18,6 +18,7 @@ */ package org.apache.pinot.broker.broker; +import java.util.Set; import org.apache.pinot.broker.api.AccessControl; import org.apache.pinot.broker.api.RequesterIdentity; import org.apache.pinot.common.request.BrokerRequest; @@ -43,5 +44,10 @@ public class AllowAllAccessControlFactory extends AccessControlFactory { public boolean hasAccess(RequesterIdentity requesterIdentity, BrokerRequest brokerRequest) { return true; } + + @Override + public boolean hasAccess(RequesterIdentity requesterIdentity, Set<String> tables) { + return true; + } } } diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/BasicAuthAccessControlFactory.java b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/BasicAuthAccessControlFactory.java index 91ae183e8c..1eb134cfa7 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/BasicAuthAccessControlFactory.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/BasicAuthAccessControlFactory.java @@ -23,6 +23,7 @@ import java.util.Collection; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; import org.apache.pinot.broker.api.AccessControl; 
import org.apache.pinot.broker.api.HttpRequesterIdentity; @@ -77,18 +78,12 @@ public class BasicAuthAccessControlFactory extends AccessControlFactory { @Override public boolean hasAccess(RequesterIdentity requesterIdentity) { - return hasAccess(requesterIdentity, null); + return hasAccess(requesterIdentity, (BrokerRequest) null); } @Override public boolean hasAccess(RequesterIdentity requesterIdentity, BrokerRequest brokerRequest) { - Preconditions.checkArgument(requesterIdentity instanceof HttpRequesterIdentity, "HttpRequesterIdentity required"); - HttpRequesterIdentity identity = (HttpRequesterIdentity) requesterIdentity; - - Collection<String> tokens = identity.getHttpHeaders().get(HEADER_AUTHORIZATION); - Optional<BasicAuthPrincipal> principalOpt = - tokens.stream().map(BasicAuthUtils::normalizeBase64Token).map(_token2principal::get).filter(Objects::nonNull) - .findFirst(); + Optional<BasicAuthPrincipal> principalOpt = getPrincipalOpt(requesterIdentity); if (!principalOpt.isPresent()) { // no matching token? reject @@ -104,5 +99,39 @@ public class BasicAuthAccessControlFactory extends AccessControlFactory { return principal.hasTable(brokerRequest.getQuerySource().getTableName()); } + + @Override + public boolean hasAccess(RequesterIdentity requesterIdentity, Set<String> tables) { + Optional<BasicAuthPrincipal> principalOpt = getPrincipalOpt(requesterIdentity); + + if (!principalOpt.isPresent()) { + // no matching token? 
reject + return false; + } + + if (tables == null || tables.isEmpty()) { + return true; + } + + BasicAuthPrincipal principal = principalOpt.get(); + for (String table : tables) { + if (!principal.hasTable(table)) { + return false; + } + } + + return true; + } + + private Optional<BasicAuthPrincipal> getPrincipalOpt(RequesterIdentity requesterIdentity) { + Preconditions.checkArgument(requesterIdentity instanceof HttpRequesterIdentity, "HttpRequesterIdentity required"); + HttpRequesterIdentity identity = (HttpRequesterIdentity) requesterIdentity; + + Collection<String> tokens = identity.getHttpHeaders().get(HEADER_AUTHORIZATION); + Optional<BasicAuthPrincipal> principalOpt = + tokens.stream().map(BasicAuthUtils::normalizeBase64Token).map(_token2principal::get).filter(Objects::nonNull) + .findFirst(); + return principalOpt; + } } } diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/ZkBasicAuthAccessControlFactory.java b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/ZkBasicAuthAccessControlFactory.java index a127f1a40a..9541492818 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/ZkBasicAuthAccessControlFactory.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/ZkBasicAuthAccessControlFactory.java @@ -20,9 +20,11 @@ package org.apache.pinot.broker.broker; import com.google.common.base.Preconditions; import java.util.Collection; +import java.util.Collections; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; import org.apache.helix.store.zk.ZkHelixPropertyStore; import org.apache.helix.zookeeper.datamodel.ZNRecord; @@ -79,11 +81,42 @@ public class ZkBasicAuthAccessControlFactory extends AccessControlFactory { @Override public boolean hasAccess(RequesterIdentity requesterIdentity) { - return hasAccess(requesterIdentity, null); + return hasAccess(requesterIdentity, (BrokerRequest) null); } @Override public boolean 
hasAccess(RequesterIdentity requesterIdentity, BrokerRequest brokerRequest) { + if (brokerRequest == null || !brokerRequest.isSetQuerySource() || !brokerRequest.getQuerySource() + .isSetTableName()) { + // no table restrictions? accept + return true; + } + + return hasAccess(requesterIdentity, Collections.singleton(brokerRequest.getQuerySource().getTableName())); + } + + @Override + public boolean hasAccess(RequesterIdentity requesterIdentity, Set<String> tables) { + Optional<ZkBasicAuthPrincipal> principalOpt = getPrincipalAuth(requesterIdentity); + if (!principalOpt.isPresent()) { + // no matching token? reject + return false; + } + if (tables == null || tables.isEmpty()) { + return true; + } + + ZkBasicAuthPrincipal principal = principalOpt.get(); + for (String table : tables) { + if (!principal.hasTable(table)) { + return false; + } + } + + return true; + } + + private Optional<ZkBasicAuthPrincipal> getPrincipalAuth(RequesterIdentity requesterIdentity) { Preconditions.checkArgument(requesterIdentity instanceof HttpRequesterIdentity, "HttpRequesterIdentity required"); HttpRequesterIdentity identity = (HttpRequesterIdentity) requesterIdentity; @@ -95,28 +128,15 @@ public class ZkBasicAuthAccessControlFactory extends AccessControlFactory { Map<String, String> name2password = tokens.stream().collect(Collectors - .toMap(BasicAuthUtils::extractUsername, BasicAuthUtils::extractPassword)); + .toMap(BasicAuthUtils::extractUsername, BasicAuthUtils::extractPassword)); Map<String, ZkBasicAuthPrincipal> password2principal = name2password.keySet().stream() - .collect(Collectors.toMap(name2password::get, _name2principal::get)); + .collect(Collectors.toMap(name2password::get, _name2principal::get)); Optional<ZkBasicAuthPrincipal> principalOpt = - password2principal.entrySet().stream() - .filter(entry -> BcryptUtils.checkpw(entry.getKey(), entry.getValue().getPassword())) - .map(u -> u.getValue()).filter(Objects::nonNull).findFirst(); - - if (!principalOpt.isPresent()) { - // 
no matching token? reject - return false; - } - - ZkBasicAuthPrincipal principal = principalOpt.get(); - if (brokerRequest == null || !brokerRequest.isSetQuerySource() || !brokerRequest.getQuerySource() - .isSetTableName()) { - // no table restrictions? accept - return true; - } - - return principal.hasTable(brokerRequest.getQuerySource().getTableName()); + password2principal.entrySet().stream() + .filter(entry -> BcryptUtils.checkpw(entry.getKey(), entry.getValue().getPassword())) + .map(u -> u.getValue()).filter(Objects::nonNull).findFirst(); + return principalOpt; } } } diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java index a1686aa47a..ebb7b7dc3e 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java @@ -21,11 +21,15 @@ package org.apache.pinot.broker.requesthandler; import com.fasterxml.jackson.databind.JsonNode; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; import org.apache.calcite.jdbc.CalciteSchemaBuilder; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; import org.apache.commons.lang3.StringUtils; import org.apache.pinot.broker.api.RequesterIdentity; import org.apache.pinot.broker.broker.AccessControlFactory; @@ -35,6 +39,7 @@ import org.apache.pinot.common.config.provider.TableCache; import org.apache.pinot.common.exception.QueryException; import org.apache.pinot.common.metrics.BrokerMeter; import org.apache.pinot.common.metrics.BrokerMetrics; +import org.apache.pinot.common.metrics.BrokerQueryPhase; import 
org.apache.pinot.common.request.BrokerRequest; import org.apache.pinot.common.response.BrokerResponse; import org.apache.pinot.common.response.broker.BrokerResponseNative; @@ -146,7 +151,7 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler { long compilationStartTimeNs; long queryTimeoutMs; - QueryPlan queryPlan; + QueryEnvironment.QueryPlannerResult queryPlanResult; try { // Parse the request sqlNodeAndOptions = sqlNodeAndOptions != null ? sqlNodeAndOptions : RequestUtils.parseQuery(query, request); @@ -156,11 +161,18 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler { compilationStartTimeNs = System.nanoTime(); switch (sqlNodeAndOptions.getSqlNode().getKind()) { case EXPLAIN: - String plan = _queryEnvironment.explainQuery(query, sqlNodeAndOptions); + queryPlanResult = _queryEnvironment.explainQuery(query, sqlNodeAndOptions); + String plan = queryPlanResult.getExplainPlan(); + RelNode explainRelRoot = queryPlanResult.getRelRoot(); + if (!hasTableAccess(requesterIdentity, getTableNamesFromRelRoot(explainRelRoot), requestId, requestContext)) { + return new BrokerResponseNative(QueryException.ACCESS_DENIED_ERROR); + } + return constructMultistageExplainPlan(query, plan); case SELECT: default: - queryPlan = _queryEnvironment.planQuery(query, sqlNodeAndOptions, requestId); + queryPlanResult = _queryEnvironment.planQuery(query, sqlNodeAndOptions, + requestId); break; } } catch (Exception e) { @@ -170,6 +182,27 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler { return new BrokerResponseNative(QueryException.getException(QueryException.SQL_PARSING_ERROR, e)); } + QueryPlan queryPlan = queryPlanResult.getQueryPlan(); + Set<String> tableNames = getTableNamesFromRelRoot(queryPlanResult.getRelRoot()); + + // Compilation Time. This includes the time taken for parsing, compiling, create stage plans and assigning workers. 
+ long compilationEndTimeNs = System.nanoTime(); + long compilationTimeNs = (compilationEndTimeNs - compilationStartTimeNs) + sqlNodeAndOptions.getParseTimeNs(); + updatePhaseTimingForTables(tableNames, BrokerQueryPhase.REQUEST_COMPILATION, compilationTimeNs); + + // Validate table access. + if (!hasTableAccess(requesterIdentity, tableNames, requestId, requestContext)) { + return new BrokerResponseNative(QueryException.ACCESS_DENIED_ERROR); + } + updatePhaseTimingForTables(tableNames, BrokerQueryPhase.AUTHORIZATION, System.nanoTime() - compilationEndTimeNs); + + // Validate QPS quota + if (hasExceededQPSQuota(tableNames, requestId, requestContext)) { + String errorMessage = + String.format("Request %d: %s exceeds query quota.", requestId, query); + return new BrokerResponseNative(QueryException.getException(QueryException.QUOTA_EXCEEDED_ERROR, errorMessage)); + } + boolean traceEnabled = Boolean.parseBoolean( request.has(CommonConstants.Broker.Request.TRACE) ? request.get(CommonConstants.Broker.Request.TRACE).asText() : "false"); @@ -180,6 +213,7 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler { stageIdStatsMap.put(stageId, new ExecutionStatsAggregator(traceEnabled)); } + long executionStartTimeNs = System.nanoTime(); try { queryResults = _queryDispatcher.submitAndReduce(requestId, queryPlan, _mailboxService, queryTimeoutMs, sqlNodeAndOptions.getOptions(), stageIdStatsMap); @@ -190,6 +224,7 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler { BrokerResponseNativeV2 brokerResponse = new BrokerResponseNativeV2(); long executionEndTimeNs = System.nanoTime(); + updatePhaseTimingForTables(tableNames, BrokerQueryPhase.QUERY_EXECUTION, executionEndTimeNs - executionStartTimeNs); // Set total query processing time long totalTimeMs = TimeUnit.NANOSECONDS.toMillis( @@ -205,11 +240,10 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler { } BrokerResponseStats brokerResponseStats = new 
BrokerResponseStats(); - List<String> tableNames = queryPlan.getStageMetadataMap().get(entry.getKey()).getScannedTables(); - if (tableNames.size() > 0) { + if (!tableNames.isEmpty()) { //TODO: Only using first table to assign broker metrics // find a way to split metrics in case of multiple table - String rawTableName = TableNameBuilder.extractRawTableName(tableNames.get(0)); + String rawTableName = TableNameBuilder.extractRawTableName(tableNames.iterator().next()); entry.getValue().setStageLevelStats(rawTableName, brokerResponseStats, _brokerMetrics); } else { entry.getValue().setStageLevelStats(null, brokerResponseStats, null); @@ -222,6 +256,51 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler { return brokerResponse; } + /** + * Validates whether the requester has access to all the tables. + */ + private boolean hasTableAccess(RequesterIdentity requesterIdentity, Set<String> tableNames, long requestId, + RequestContext requestContext) { + boolean hasAccess = _accessControlFactory.create().hasAccess(requesterIdentity, tableNames); + if (!hasAccess) { + _brokerMetrics.addMeteredGlobalValue(BrokerMeter.REQUEST_DROPPED_DUE_TO_ACCESS_ERROR, 1); + LOGGER.warn("Access denied for requestId {}", requestId); + requestContext.setErrorCode(QueryException.ACCESS_DENIED_ERROR_CODE); + return false; + } + + return true; + } + + /** + * Returns true if the QPS quota of the tables has exceeded. 
+ */ + private boolean hasExceededQPSQuota(Set<String> tableNames, long requestId, RequestContext requestContext) { + for (String tableName : tableNames) { + if (!_queryQuotaManager.acquire(tableName)) { + LOGGER.warn("Request {}: query exceeds quota for table: {}", requestId, tableName); + requestContext.setErrorCode(QueryException.TOO_MANY_REQUESTS_ERROR_CODE); + String rawTableName = TableNameBuilder.extractRawTableName(tableName); + _brokerMetrics.addMeteredTableValue(rawTableName, BrokerMeter.QUERY_QUOTA_EXCEEDED, 1); + return true; + } + } + return false; + } + + private Set<String> getTableNamesFromRelRoot(RelNode relRoot) { + return new HashSet<>(RelOptUtil.findAllTableQualifiedNames(relRoot)); + } + + private void updatePhaseTimingForTables(Set<String> tableNames, + BrokerQueryPhase phase, long time) { + for (String tableName : tableNames) { + String rawTableName = TableNameBuilder.extractRawTableName(tableName); + _brokerMetrics.addPhaseTiming(rawTableName, phase, time); + } + } + + private BrokerResponseNative constructMultistageExplainPlan(String sql, String plan) { BrokerResponseNative brokerResponse = BrokerResponseNative.empty(); List<Object[]> rows = new ArrayList<>(); diff --git a/pinot-broker/src/test/java/org/apache/pinot/broker/broker/BasicAuthAccessControlTest.java b/pinot-broker/src/test/java/org/apache/pinot/broker/broker/BasicAuthAccessControlTest.java index f9e60f4791..a491e4f402 100644 --- a/pinot-broker/src/test/java/org/apache/pinot/broker/broker/BasicAuthAccessControlTest.java +++ b/pinot-broker/src/test/java/org/apache/pinot/broker/broker/BasicAuthAccessControlTest.java @@ -21,7 +21,9 @@ package org.apache.pinot.broker.broker; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.Multimap; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; +import java.util.Set; import org.apache.pinot.broker.api.AccessControl; import org.apache.pinot.broker.api.HttpRequesterIdentity; import 
org.apache.pinot.common.request.BrokerRequest; @@ -40,13 +42,20 @@ public class BasicAuthAccessControlTest { private AccessControl _accessControl; + Set<String> _tableNames; + @BeforeClass public void setup() { Map<String, Object> config = new HashMap<>(); config.put("principals", "admin,user"); config.put("principals.admin.password", "verysecret"); config.put("principals.user.password", "secret"); - config.put("principals.user.tables", "lessImportantStuff"); + config.put("principals.user.tables", "lessImportantStuff,lesserImportantStuff,leastImportantStuff"); + + _tableNames = new HashSet<>(); + _tableNames.add("lessImportantStuff"); + _tableNames.add("lesserImportantStuff"); + _tableNames.add("leastImportantStuff"); AccessControlFactory factory = new BasicAuthAccessControlFactory(); factory.init(new PinotConfiguration(config)); @@ -56,7 +65,7 @@ public class BasicAuthAccessControlTest { @Test(expectedExceptions = IllegalArgumentException.class) public void testNullEntity() { - _accessControl.hasAccess(null, null); + _accessControl.hasAccess(null, (BrokerRequest) null); } @Test @@ -66,7 +75,7 @@ public class BasicAuthAccessControlTest { HttpRequesterIdentity identity = new HttpRequesterIdentity(); identity.setHttpHeaders(headers); - Assert.assertFalse(_accessControl.hasAccess(identity, null)); + Assert.assertFalse(_accessControl.hasAccess(identity, (BrokerRequest) null)); } @Test @@ -84,6 +93,7 @@ public class BasicAuthAccessControlTest { request.setQuerySource(source); Assert.assertTrue(_accessControl.hasAccess(identity, request)); + Assert.assertTrue(_accessControl.hasAccess(identity, _tableNames)); } @Test @@ -101,6 +111,14 @@ public class BasicAuthAccessControlTest { request.setQuerySource(source); Assert.assertFalse(_accessControl.hasAccess(identity, request)); + + Set<String> tableNames = new HashSet<>(); + tableNames.add("veryImportantStuff"); + Assert.assertFalse(_accessControl.hasAccess(identity, tableNames)); + tableNames.add("lessImportantStuff"); + 
Assert.assertFalse(_accessControl.hasAccess(identity, tableNames)); + tableNames.add("lesserImportantStuff"); + Assert.assertFalse(_accessControl.hasAccess(identity, tableNames)); } @Test @@ -118,6 +136,13 @@ public class BasicAuthAccessControlTest { request.setQuerySource(source); Assert.assertTrue(_accessControl.hasAccess(identity, request)); + + Set<String> tableNames = new HashSet<>(); + tableNames.add("lessImportantStuff"); + tableNames.add("veryImportantStuff"); + tableNames.add("lesserImportantStuff"); + + Assert.assertTrue(_accessControl.hasAccess(identity, tableNames)); } @Test @@ -131,6 +156,9 @@ public class BasicAuthAccessControlTest { BrokerRequest request = new BrokerRequest(); Assert.assertTrue(_accessControl.hasAccess(identity, request)); + + Set<String> tableNames = new HashSet<>(); + Assert.assertTrue(_accessControl.hasAccess(identity, tableNames)); } @Test @@ -148,5 +176,6 @@ public class BasicAuthAccessControlTest { request.setQuerySource(source); Assert.assertTrue(_accessControl.hasAccess(identity, request)); + Assert.assertTrue(_accessControl.hasAccess(identity, _tableNames)); } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java index 3ee12d36f5..ff9b9150b1 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java @@ -21,6 +21,7 @@ package org.apache.pinot.query; import com.google.common.annotations.VisibleForTesting; import java.util.Arrays; import java.util.Properties; +import javax.annotation.Nullable; import org.apache.calcite.config.CalciteConnectionConfigImpl; import org.apache.calcite.config.CalciteConnectionProperty; import org.apache.calcite.jdbc.CalciteSchema; @@ -142,13 +143,13 @@ public class QueryEnvironment { * * @param sqlQuery SQL query string. * @param sqlNodeAndOptions parsed SQL query. 
- * @return a dispatchable query plan + * @return QueryPlannerResult containing the dispatchable query plan and the relRoot. */ - public QueryPlan planQuery(String sqlQuery, SqlNodeAndOptions sqlNodeAndOptions, long requestId) { + public QueryPlannerResult planQuery(String sqlQuery, SqlNodeAndOptions sqlNodeAndOptions, long requestId) { try (PlannerContext plannerContext = new PlannerContext(_config, _catalogReader, _typeFactory, _hepProgram)) { plannerContext.setOptions(sqlNodeAndOptions.getOptions()); RelRoot relRoot = compileQuery(sqlNodeAndOptions.getSqlNode(), plannerContext); - return toDispatchablePlan(relRoot, plannerContext, requestId); + return new QueryPlannerResult(toDispatchablePlan(relRoot, plannerContext, requestId), null, relRoot.rel); } catch (CalciteContextException e) { throw new RuntimeException("Error composing query plan for '" + sqlQuery + "': " + e.getMessage() + "'", e); @@ -166,9 +167,9 @@ public class QueryEnvironment { * * @param sqlQuery SQL query string. * @param sqlNodeAndOptions parsed SQL query. - * @return the explained query plan. + * @return QueryPlannerResult containing the explained query plan and the relRoot. */ - public String explainQuery(String sqlQuery, SqlNodeAndOptions sqlNodeAndOptions) { + public QueryPlannerResult explainQuery(String sqlQuery, SqlNodeAndOptions sqlNodeAndOptions) { try (PlannerContext plannerContext = new PlannerContext(_config, _catalogReader, _typeFactory, _hepProgram)) { SqlExplain explain = (SqlExplain) sqlNodeAndOptions.getSqlNode(); plannerContext.setOptions(sqlNodeAndOptions.getOptions()); @@ -176,7 +177,7 @@ public class QueryEnvironment { SqlExplainFormat format = explain.getFormat() == null ? SqlExplainFormat.DOT : explain.getFormat(); SqlExplainLevel level = explain.getDetailLevel() == null ? 
SqlExplainLevel.DIGEST_ATTRIBUTES : explain.getDetailLevel(); - return PlannerUtils.explainPlan(relRoot.rel, format, level); + return new QueryPlannerResult(null, PlannerUtils.explainPlan(relRoot.rel, format, level), relRoot.rel); } catch (Exception e) { throw new RuntimeException("Error explain query plan for: " + sqlQuery, e); } @@ -184,12 +185,39 @@ public class QueryEnvironment { @VisibleForTesting public QueryPlan planQuery(String sqlQuery) { - return planQuery(sqlQuery, CalciteSqlParser.compileToSqlNodeAndOptions(sqlQuery), 0); + return planQuery(sqlQuery, CalciteSqlParser.compileToSqlNodeAndOptions(sqlQuery), 0).getQueryPlan(); } @VisibleForTesting public String explainQuery(String sqlQuery) { - return explainQuery(sqlQuery, CalciteSqlParser.compileToSqlNodeAndOptions(sqlQuery)); + return explainQuery(sqlQuery, CalciteSqlParser.compileToSqlNodeAndOptions(sqlQuery)).getExplainPlan(); + } + + /** + * Results of planning a query + */ + public static class QueryPlannerResult { + private QueryPlan _queryPlan; + private String _explainPlan; + private RelNode _relRoot; + + QueryPlannerResult(@Nullable QueryPlan queryPlan, @Nullable String explainPlan, RelNode relRoot) { + _queryPlan = queryPlan; + _explainPlan = explainPlan; + _relRoot = relRoot; + } + + public String getExplainPlan() { + return _explainPlan; + } + + public QueryPlan getQueryPlan() { + return _queryPlan; + } + + public RelNode getRelRoot() { + return _relRoot; + } } // -------------------------------------------------------------------------- --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org