This is an automated email from the ASF dual-hosted git repository. yashmayya pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new f1cbec74aad Adding changes for supporting RLS (#16043) f1cbec74aad is described below commit f1cbec74aad225460cd301a3aed4465ed87cc4f9 Author: 9aman <35227405+9a...@users.noreply.github.com> AuthorDate: Thu Jun 26 12:13:35 2025 +0530 Adding changes for supporting RLS (#16043) --- .../org/apache/pinot/broker/api/AccessControl.java | 13 + .../broker/BasicAuthAccessControlFactory.java | 25 ++ .../BaseSingleStageBrokerRequestHandler.java | 25 ++ .../MultiStageBrokerRequestHandler.java | 11 + .../apache/pinot/sql/parsers/CalciteSqlParser.java | 20 +- .../sql/parsers/rewriter/QueryRewriterFactory.java | 2 +- .../sql/parsers/rewriter/RlsFiltersRewriter.java | 79 +++++ .../pinot/sql/parsers/rewriter/RlsUtils.java | 44 +++ .../parsers/rewriter/QueryRewriterFactoryTest.java | 6 +- .../apache/pinot/core/auth/BasicAuthPrincipal.java | 29 +- .../org/apache/pinot/core/auth/BasicAuthUtils.java | 28 +- .../pinot/core/auth/ZkBasicAuthPrincipal.java | 10 +- .../org/apache/pinot/core/auth/BasicAuthTest.java | 104 +++++++ .../apache/pinot/core/auth/ZkBasicAuthTest.java | 80 ++--- .../tests/RowLevelSecurityIntegrationTest.java | 330 +++++++++++++++++++++ .../org/apache/pinot/query/QueryEnvironment.java | 4 + .../apache/pinot/query/runtime/QueryRunner.java | 7 +- .../plan/server/ServerPlanRequestUtils.java | 20 +- .../pinot/spi/auth/TableRowColAccessResult.java | 36 +++ .../spi/auth/TableRowColAccessResultImpl.java | 51 ++++ .../apache/pinot/spi/utils/CommonConstants.java | 2 + 21 files changed, 875 insertions(+), 51 deletions(-) diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/api/AccessControl.java b/pinot-broker/src/main/java/org/apache/pinot/broker/api/AccessControl.java index 78dae4cd0ba..45a7eafec41 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/api/AccessControl.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/api/AccessControl.java @@ -26,6 +26,8 @@ import org.apache.pinot.spi.annotations.InterfaceStability; import org.apache.pinot.spi.auth.AuthorizationResult; import org.apache.pinot.spi.auth.BasicAuthorizationResultImpl; import org.apache.pinot.spi.auth.TableAuthorizationResult; +import org.apache.pinot.spi.auth.TableRowColAccessResult; +import org.apache.pinot.spi.auth.TableRowColAccessResultImpl; import org.apache.pinot.spi.auth.broker.RequesterIdentity; @@ -120,4 +122,15 @@ public interface AccessControl extends FineGrainedAccessControl { return hasAccess(requesterIdentity, tables) ? TableAuthorizationResult.success() : new TableAuthorizationResult(tables); } + + + /** + * Returns RLS/CLS filters for a particular table. By default, there are no RLS/CLS filters on any table. + * @param requesterIdentity requested identity + * @param table Table used in the query. Table name can be with or without tableType. + * @return {@link TableRowColAccessResult} with the result of the access control check + */ + default TableRowColAccessResult getRowColFilters(RequesterIdentity requesterIdentity, String table) { + return TableRowColAccessResultImpl.unrestricted(); + } } diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/BasicAuthAccessControlFactory.java b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/BasicAuthAccessControlFactory.java index 129ac75f293..d37a14ab4de 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/BasicAuthAccessControlFactory.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/BasicAuthAccessControlFactory.java @@ -18,13 +18,16 @@ */ package org.apache.pinot.broker.broker; +import com.google.common.base.Preconditions; import java.util.Collection; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import javax.validation.constraints.NotNull; import javax.ws.rs.NotAuthorizedException; import org.apache.pinot.broker.api.AccessControl; import org.apache.pinot.common.request.BrokerRequest; @@ -32,6 +35,8 @@ import org.apache.pinot.core.auth.BasicAuthPrincipal; import org.apache.pinot.core.auth.BasicAuthUtils; import org.apache.pinot.spi.auth.AuthorizationResult; import org.apache.pinot.spi.auth.TableAuthorizationResult; +import org.apache.pinot.spi.auth.TableRowColAccessResult; +import org.apache.pinot.spi.auth.TableRowColAccessResultImpl; import org.apache.pinot.spi.auth.broker.RequesterIdentity; import org.apache.pinot.spi.env.PinotConfiguration; @@ -131,6 +136,26 @@ public class BasicAuthAccessControlFactory extends AccessControlFactory { return new TableAuthorizationResult(failedTables); } + @Override + public TableRowColAccessResult getRowColFilters(RequesterIdentity requesterIdentity, @NotNull String table) { + Optional<BasicAuthPrincipal> principalOpt = getPrincipalOpt(requesterIdentity); + + Preconditions.checkState(principalOpt.isPresent(), "Principal is not authorized"); + Preconditions.checkState(table != null, "Table cannot be null"); + + TableRowColAccessResult tableRowColAccessResult = new TableRowColAccessResultImpl(); + BasicAuthPrincipal principal = principalOpt.get(); + + //precondition: The principal should have the table. + Preconditions.checkArgument(principal.hasTable(table), + "Principal: " + principal.getName() + " does not have access to table: " + table); + + Optional<List<String>> rlsFiltersMaybe = principal.getRLSFilters(table); + rlsFiltersMaybe.ifPresent(tableRowColAccessResult::setRLSFilters); + + return tableRowColAccessResult; + } + private Optional<BasicAuthPrincipal> getPrincipalOpt(RequesterIdentity requesterIdentity) { Collection<String> tokens = extractAuthorizationTokens(requesterIdentity); if (tokens.isEmpty()) { diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java index 2e6fbcca290..c576918e7a2 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java @@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableMap; import java.net.URI; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -98,6 +99,7 @@ import org.apache.pinot.query.routing.table.LogicalTableRouteProvider; import org.apache.pinot.query.routing.table.TableRouteProvider; import org.apache.pinot.segment.local.function.GroovyFunctionEvaluator; import org.apache.pinot.spi.auth.AuthorizationResult; +import org.apache.pinot.spi.auth.TableRowColAccessResult; import org.apache.pinot.spi.auth.broker.RequesterIdentity; import org.apache.pinot.spi.config.table.FieldConfig; import org.apache.pinot.spi.config.table.QueryConfig; @@ -122,6 +124,8 @@ import org.apache.pinot.sql.FilterKind; import org.apache.pinot.sql.parsers.CalciteSqlCompiler; import org.apache.pinot.sql.parsers.CalciteSqlParser; import org.apache.pinot.sql.parsers.SqlNodeAndOptions; +import org.apache.pinot.sql.parsers.rewriter.RlsFiltersRewriter; +import org.apache.pinot.sql.parsers.rewriter.RlsUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -434,6 +438,27 @@ public abstract class BaseSingleStageBrokerRequestHandler extends BaseBrokerRequ throwAccessDeniedError(requestId, query, requestContext, tableName, authorizationResult); } + TableRowColAccessResult rlsFilters = accessControl.getRowColFilters(requesterIdentity, tableName); + + //rewrite query + Map<String, String> queryOptions = + pinotQuery.getQueryOptions() == null ? new HashMap<>() : pinotQuery.getQueryOptions(); + + rlsFilters.getRLSFilters().ifPresent(rowFilters -> { + String combinedFilters = + rowFilters.stream().map(filter -> "( " + filter + " )").collect(Collectors.joining(" AND ")); + String rowFiltersKey = RlsUtils.buildRlsFilterKey(rawTableName); + queryOptions.put(rowFiltersKey, combinedFilters); + pinotQuery.setQueryOptions(queryOptions); + try { + CalciteSqlParser.queryRewrite(pinotQuery, RlsFiltersRewriter.class); + } catch (Exception e) { + LOGGER.error( + "Unable to apply RLS filter: {}. Row-level security filtering will be disabled for this query.", + RlsFiltersRewriter.class.getName(), e); + } + }); + // Validate QPS quota if (!_queryQuotaManager.acquireDatabase(database)) { String errorMessage = diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java index 374c9c1e8b9..6193eccc164 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java @@ -93,6 +93,7 @@ import org.apache.pinot.spi.trace.RequestContext; import org.apache.pinot.spi.trace.Tracing; import org.apache.pinot.spi.utils.CommonConstants; import org.apache.pinot.sql.parsers.SqlNodeAndOptions; +import org.apache.pinot.sql.parsers.rewriter.RlsUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.slf4j.Marker; @@ -343,6 +344,16 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler { if (!tableAuthorizationResult.hasAccess()) { throwTableAccessError(tableAuthorizationResult); } + AccessControl accessControl = _accessControlFactory.create(); + for (String tableName : tables) { + accessControl.getRowColFilters(requesterIdentity, tableName).getRLSFilters() + .ifPresent(rowFilters -> { + String combinedFilters = + rowFilters.stream().map(filter -> "( " + filter + " )").collect(Collectors.joining(" AND ")); + String key = RlsUtils.buildRlsFilterKey(tableName); + compiledQuery.getOptions().put(key, combinedFilters); + }); + } } } diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java index b2a43d66e07..cd893462168 100644 --- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java +++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java @@ -558,7 +558,7 @@ public class CalciteSqlParser { return join; } - private static void queryRewrite(PinotQuery pinotQuery) { + public static void queryRewrite(PinotQuery pinotQuery) { for (QueryRewriter queryRewriter : QUERY_REWRITERS) { pinotQuery = queryRewriter.rewrite(pinotQuery); } @@ -566,6 +566,24 @@ public class CalciteSqlParser { validate(pinotQuery); } + /** + * Applies a specific query rewriter to the given PinotQuery and validates the result. + * This method searches for a rewriter by class name and applies it to transform the query. + * + * @param pinotQuery the query to be rewritten + * @param rewriterClass the class name of the query rewriter to apply + * @throws IllegalArgumentException if no rewriter with the specified class name is found + */ + public static void queryRewrite(PinotQuery pinotQuery, Class<? extends QueryRewriter> rewriterClass) { + QueryRewriter queryRewriter = QUERY_REWRITERS.stream() + .filter(rewriter -> rewriter.getClass().equals(rewriterClass)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("Query rewriter not found: " + rewriterClass.getName())); + queryRewriter.rewrite(pinotQuery); + validate(pinotQuery); + } + + @Deprecated private static List<String> extractOptionsFromSql(String sql) { List<String> results = new ArrayList<>(); diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/QueryRewriterFactory.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/QueryRewriterFactory.java index 4bd2abf6090..9117a3c6aed 100644 --- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/QueryRewriterFactory.java +++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/QueryRewriterFactory.java @@ -40,7 +40,7 @@ public class QueryRewriterFactory { public static final List<String> DEFAULT_QUERY_REWRITERS_CLASS_NAMES = ImmutableList.of(CompileTimeFunctionsInvoker.class.getName(), SelectionsRewriter.class.getName(), PredicateComparisonRewriter.class.getName(), AliasApplier.class.getName(), OrdinalsUpdater.class.getName(), - NonAggregationGroupByToDistinctQueryRewriter.class.getName()); + NonAggregationGroupByToDistinctQueryRewriter.class.getName(), RlsFiltersRewriter.class.getName()); public static void init(String queryRewritersClassNamesStr) { List<String> queryRewritersClassNames = diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/RlsFiltersRewriter.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/RlsFiltersRewriter.java new file mode 100644 index 00000000000..dbd9f2bd2d0 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/RlsFiltersRewriter.java @@ -0,0 +1,79 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.sql.parsers.rewriter; + +import java.util.List; +import java.util.Map; +import org.apache.commons.collections4.MapUtils; +import org.apache.logging.log4j.util.Strings; +import org.apache.pinot.common.request.Expression; +import org.apache.pinot.common.request.PinotQuery; +import org.apache.pinot.common.utils.request.RequestUtils; +import org.apache.pinot.spi.utils.builder.TableNameBuilder; +import org.apache.pinot.sql.FilterKind; +import org.apache.pinot.sql.parsers.CalciteSqlParser; + +/** + * A query rewriter that applies Row-Level Security (RLS) filters to Pinot queries. + * + * <p>This rewriter examines query options for table-specific row filters and automatically + * applies them to the query's WHERE clause. The RLS filters are retrieved from the query + * options using the prefixed table name as the key. + * + * <p>The rewriter performs the following operations: + * <ul> + * <li>Extracts the raw table name from the query's data source</li> + * <li>Looks up RLS filters from query options using the table name</li> + * <li>Parses the filter string into an Expression object</li> + * <li>Combines the RLS filter with any existing WHERE clause using AND logic</li> + * </ul> + * + * <p>If no query options are present, no RLS filters are found for the table, or the + * filter string is empty, the query is returned unchanged. + * + * @see QueryRewriter + */ +public class RlsFiltersRewriter implements QueryRewriter { + + @Override + public PinotQuery rewrite(PinotQuery pinotQuery) { + Map<String, String> queryOptions = pinotQuery.getQueryOptions(); + if (MapUtils.isEmpty(queryOptions)) { + return pinotQuery; + } + String tableName = pinotQuery.getDataSource().getTableName(); + String rawTableName = TableNameBuilder.extractRawTableName(tableName); + String rowFilters = RlsUtils.getRlsFilterForTable(queryOptions, rawTableName); + + if (Strings.isEmpty(rowFilters)) { + return pinotQuery; + } + + Expression expression = CalciteSqlParser.compileToExpression(rowFilters); + Expression existingFilterExpression = pinotQuery.getFilterExpression(); + if (existingFilterExpression != null) { + Expression combinedFilterExpression = + RequestUtils.getFunctionExpression(FilterKind.AND.name(), List.of(expression, existingFilterExpression)); + pinotQuery.setFilterExpression(combinedFilterExpression); + } else { + pinotQuery.setFilterExpression(expression); + } + return pinotQuery; + } +} diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/RlsUtils.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/RlsUtils.java new file mode 100644 index 00000000000..7d7524b28ab --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/RlsUtils.java @@ -0,0 +1,44 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.sql.parsers.rewriter; + +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.pinot.spi.utils.CommonConstants; + + +public class RlsUtils { + + private RlsUtils() { + } + + public static String buildRlsFilterKey(String tableName) { + return String.format("%s-%s", CommonConstants.RLS_FILTERS, tableName); + } + + public static Map<String, String> extractRlsFilters(Map<String, String> requestMetadata) { + return requestMetadata.entrySet().stream() + .filter(e -> e.getKey().startsWith(CommonConstants.RLS_FILTERS)) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + public static String getRlsFilterForTable(Map<String, String> queryOptions, String tableName) { + return queryOptions.get(buildRlsFilterKey(tableName)); + } +} diff --git a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/QueryRewriterFactoryTest.java b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/QueryRewriterFactoryTest.java index 7288c1f8435..f3dcb235e38 100644 --- a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/QueryRewriterFactoryTest.java +++ b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/QueryRewriterFactoryTest.java @@ -30,13 +30,14 @@ public class QueryRewriterFactoryTest { public void testQueryRewriters() { // Default behavior QueryRewriterFactory.init(null); - Assert.assertEquals(QUERY_REWRITERS.size(), 6); + Assert.assertEquals(QUERY_REWRITERS.size(), 7); Assert.assertTrue(QUERY_REWRITERS.get(0) instanceof CompileTimeFunctionsInvoker); Assert.assertTrue(QUERY_REWRITERS.get(1) instanceof SelectionsRewriter); Assert.assertTrue(QUERY_REWRITERS.get(2) instanceof PredicateComparisonRewriter); Assert.assertTrue(QUERY_REWRITERS.get(3) instanceof AliasApplier); Assert.assertTrue(QUERY_REWRITERS.get(4) instanceof OrdinalsUpdater); Assert.assertTrue(QUERY_REWRITERS.get(5) instanceof NonAggregationGroupByToDistinctQueryRewriter); + Assert.assertTrue(QUERY_REWRITERS.get(6) instanceof RlsFiltersRewriter); // Check init with other configs QueryRewriterFactory.init("org.apache.pinot.sql.parsers.rewriter.PredicateComparisonRewriter," @@ -49,12 +50,13 @@ public class QueryRewriterFactoryTest { // Revert back to default behavior QueryRewriterFactory.init(null); - Assert.assertEquals(QUERY_REWRITERS.size(), 6); + Assert.assertEquals(QUERY_REWRITERS.size(), 7); Assert.assertTrue(QUERY_REWRITERS.get(0) instanceof CompileTimeFunctionsInvoker); Assert.assertTrue(QUERY_REWRITERS.get(1) instanceof SelectionsRewriter); Assert.assertTrue(QUERY_REWRITERS.get(2) instanceof PredicateComparisonRewriter); Assert.assertTrue(QUERY_REWRITERS.get(3) instanceof AliasApplier); Assert.assertTrue(QUERY_REWRITERS.get(4) instanceof OrdinalsUpdater); Assert.assertTrue(QUERY_REWRITERS.get(5) instanceof NonAggregationGroupByToDistinctQueryRewriter); + Assert.assertTrue(QUERY_REWRITERS.get(6) instanceof RlsFiltersRewriter); } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/auth/BasicAuthPrincipal.java b/pinot-core/src/main/java/org/apache/pinot/core/auth/BasicAuthPrincipal.java index f53fd8c4953..c238d0d1260 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/auth/BasicAuthPrincipal.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/auth/BasicAuthPrincipal.java @@ -18,6 +18,9 @@ */ package org.apache.pinot.core.auth; +import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -31,14 +34,22 @@ public class BasicAuthPrincipal { private final Set<String> _tables; private final Set<String> _excludeTables; private final Set<String> _permissions; + //key: table name, val: list of RLS filters applicable for that table. + private final Map<String, List<String>> _rlsFilters; public BasicAuthPrincipal(String name, String token, Set<String> tables, Set<String> excludeTables, Set<String> permissions) { + this(name, token, tables, excludeTables, permissions, null); + } + + public BasicAuthPrincipal(String name, String token, Set<String> tables, Set<String> excludeTables, + Set<String> permissions, Map<String, List<String>> rlsFilters) { _name = name; _token = token; _tables = tables; _excludeTables = excludeTables; _permissions = permissions.stream().map(s -> s.toLowerCase()).collect(Collectors.toSet()); + _rlsFilters = rlsFilters; } public String getName() { @@ -65,13 +76,27 @@ public class BasicAuthPrincipal { return _permissions.isEmpty() || _permissions.contains(permission.toLowerCase()); } + /** + * Gets the Row-Level Security (RLS) filter configured for the given table. + * The RLS filter is applied only if the user has access to the table + * (as determined by {@link #hasTable(String)}). + * + * @param tableName The name of the table. + * @return An {@link java.util.Optional} containing the RLS filter string if configured for this principal and table, + * otherwise {@link java.util.Optional#empty()}. + */ + public Optional<List<String>> getRLSFilters(String tableName) { + return Optional.ofNullable(_rlsFilters.get(tableName)); + } + @Override public String toString() { return "BasicAuthPrincipal{" + "_name='" + _name + '\'' + ", _token='" + _token + '\'' - + ", _tables=" + _tables - + ", _permissions=" + _permissions + + ", _tables=" + _tables + '\'' + + ", _permissions=" + _permissions + '\'' + + ",_rlsFilters=" + _rlsFilters + '}'; } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/auth/BasicAuthUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/auth/BasicAuthUtils.java index 94f6b6bd95c..e27816b23da 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/auth/BasicAuthUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/auth/BasicAuthUtils.java @@ -21,7 +21,9 @@ package org.apache.pinot.core.auth; import com.google.common.base.Preconditions; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -39,6 +41,7 @@ public final class BasicAuthUtils { private static final String TABLES = "tables"; private static final String EXCLUDE_TABLES = "excludeTables"; private static final String ALL = "*"; + private static final String RLS_FILTER = "rls"; private BasicAuthUtils() { // left blank @@ -76,8 +79,26 @@ public final class BasicAuthUtils { Set<String> excludeTables = extractSet(configuration, prefix + "." + name + "." + EXCLUDE_TABLES); Set<String> permissions = extractSet(configuration, prefix + "." + name + "." + PERMISSIONS); + // Extract RLS filters + Map<String, List<String>> tableRlsFilters = new HashMap<>(); // Changed to Map<String, List<String>> + if (!tables.isEmpty()) { + for (String tableName : tables) { + String rlsFilterKey = prefix + "." + name + "." + tableName + "." + RLS_FILTER; + String csvRlsFilters = configuration.getProperty(rlsFilterKey); // This is a CSV string + + if (StringUtils.isNotBlank(csvRlsFilters)) { + List<String> rlsFilterList = Arrays.stream(csvRlsFilters.split(",")).map(String::trim) + .filter(StringUtils::isNotBlank) // Ensure individual filters are not blank + .collect(Collectors.toList()); + if (!rlsFilterList.isEmpty()) { + tableRlsFilters.put(tableName, rlsFilterList); + } + } + } + } + return new BasicAuthPrincipal(name, org.apache.pinot.common.auth.BasicAuthUtils.toBasicAuthToken(name, password), - tables, excludeTables, permissions); + tables, excludeTables, permissions, tableRlsFilters); }).collect(Collectors.toList()); } @@ -101,9 +122,10 @@ public final class BasicAuthUtils { .orElseGet(() -> Collections.emptyList()) .stream().map(x -> x.toString()) .collect(Collectors.toSet()); + //todo: Handle RLS filters properly return new ZkBasicAuthPrincipal(name, - org.apache.pinot.common.auth.BasicAuthUtils.toBasicAuthToken(name, password), password, - component, role, tables, excludeTables, permissions); + org.apache.pinot.common.auth.BasicAuthUtils.toBasicAuthToken(name, password), password, component, role, + tables, excludeTables, permissions, Map.of()); }).collect(Collectors.toList()); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/auth/ZkBasicAuthPrincipal.java b/pinot-core/src/main/java/org/apache/pinot/core/auth/ZkBasicAuthPrincipal.java index a4ee23e035f..7bd28362ebb 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/auth/ZkBasicAuthPrincipal.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/auth/ZkBasicAuthPrincipal.java @@ -18,6 +18,8 @@ */ package org.apache.pinot.core.auth; +import java.util.List; +import java.util.Map; import java.util.Set; import org.apache.pinot.spi.config.user.ComponentType; import org.apache.pinot.spi.config.user.RoleType; @@ -32,7 +34,13 @@ public class ZkBasicAuthPrincipal extends BasicAuthPrincipal { public ZkBasicAuthPrincipal(String name, String token, String password, String component, String role, Set<String> tables, Set<String> excludeTables, Set<String> permissions) { - super(name, token, tables, excludeTables, permissions); + this(name, token, password, component, role, tables, excludeTables, permissions, null); + } + + public ZkBasicAuthPrincipal(String name, String token, String password, String component, String role, + Set<String> tables, Set<String> excludeTables, Set<String> permissions, + Map<String, List<String>> tableRLSFilters) { + super(name, token, tables, excludeTables, permissions, tableRLSFilters); _component = component; _role = role; _password = password; diff --git a/pinot-core/src/test/java/org/apache/pinot/core/auth/BasicAuthTest.java b/pinot-core/src/test/java/org/apache/pinot/core/auth/BasicAuthTest.java index 20897c8ee7c..53225be41bc 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/auth/BasicAuthTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/auth/BasicAuthTest.java @@ -18,8 +18,15 @@ */ package org.apache.pinot.core.auth; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import org.apache.pinot.spi.env.PinotConfiguration; import org.testng.Assert; import org.testng.annotations.Test; @@ -62,5 +69,102 @@ public class BasicAuthTest { Assert.assertFalse(new BasicAuthPrincipal("name", "token", ImmutableSet.of("myTable"), Collections.emptySet(), ImmutableSet.of("read")) .hasPermission("write")); + + Assert.assertEquals(new BasicAuthPrincipal("name", "token", ImmutableSet.of("myTable"), Collections.emptySet(), + ImmutableSet.of("read"), Map.of("myTable", ImmutableList.of("cityID > 100"))) + .getRLSFilters("myTable"), Optional.of(ImmutableList.of("cityID > 100"))); + } + + @Test + public void testExtractBasicAuthPrincipals() { + // Test basic configuration with multiple principals + Map<String, Object> config = new HashMap<>(); + config.put("principals", "admin,user"); + config.put("principals.admin.password", "verysecret"); + config.put("principals.user.password", "secret"); + config.put("principals.user.tables", "lessImportantStuff,lesserImportantStuff,leastImportantStuff"); + config.put("principals.user.excludeTables", "excludedTable"); + config.put("principals.user.permissions", "read,write"); + config.put("principals.user.lessImportantStuff.rls", "cityID > 100,status = 'active'"); + config.put("principals.user.lesserImportantStuff.rls", "region = 'US'"); + + PinotConfiguration configuration = new PinotConfiguration(config); + List<BasicAuthPrincipal> principals = BasicAuthUtils.extractBasicAuthPrincipals(configuration, "principals"); + + Assert.assertEquals(principals.size(), 2); + + // Verify admin principal (should have no table restrictions) + BasicAuthPrincipal adminPrincipal = principals.stream() + .filter(p -> p.getName().equals("admin")) + .findFirst() + .orElse(null); + Assert.assertNotNull(adminPrincipal); + Assert.assertEquals(adminPrincipal.getName(), "admin"); + + // Verify user principal + BasicAuthPrincipal userPrincipal = principals.stream() + .filter(p -> p.getName().equals("user")) + .findFirst() + .orElse(null); + Assert.assertNotNull(userPrincipal); + Assert.assertEquals(userPrincipal.getName(), "user"); + + Set<String> expectedTables = ImmutableSet.of("lessImportantStuff", "lesserImportantStuff", "leastImportantStuff"); + expectedTables.forEach(tableName -> { + Assert.assertTrue(userPrincipal.hasTable(tableName)); + }); + + Set<String> expectedExcludeTables = ImmutableSet.of("excludedTable"); + expectedExcludeTables.forEach(tableName -> { + Assert.assertFalse(userPrincipal.hasTable(tableName)); + }); + + Set<String> expectedPermissions = ImmutableSet.of("read", "write"); + expectedPermissions.forEach(permission -> { + Assert.assertTrue(userPrincipal.hasPermission(permission)); + }); + + // Verify RLS filters + + List<String> lessImportantStuffFilters = userPrincipal.getRLSFilters("lessImportantStuff").get(); + Assert.assertNotNull(lessImportantStuffFilters); + Assert.assertEquals(lessImportantStuffFilters.size(), 2); + Assert.assertTrue(lessImportantStuffFilters.contains("cityID > 100")); + Assert.assertTrue(lessImportantStuffFilters.contains("status = 'active'")); + + List<String> lesserImportantStuffFilters = userPrincipal.getRLSFilters("lesserImportantStuff").get(); + Assert.assertNotNull(lesserImportantStuffFilters); + Assert.assertEquals(lesserImportantStuffFilters.size(), 1); + Assert.assertTrue(lesserImportantStuffFilters.contains("region = 'US'")); + + // Verify no RLS filters for leastImportantStuff (not configured) + Assert.assertTrue(userPrincipal.getRLSFilters("leastImportantStuff").isEmpty()); + } + + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "must provide " + + "principals") + public void testExtractBasicAuthPrincipalsNoPrincipals() { + Map<String, Object> config = new HashMap<>(); + PinotConfiguration configuration = new PinotConfiguration(config); + BasicAuthUtils.extractBasicAuthPrincipals(configuration, "principals"); + } + + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "must provide a " + + "password for.*") + public void testExtractBasicAuthPrincipalsNoPassword() { + Map<String, Object> config = new HashMap<>(); + config.put("principals", "admin"); + PinotConfiguration configuration = new PinotConfiguration(config); + BasicAuthUtils.extractBasicAuthPrincipals(configuration, "principals"); + } + + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = ".* is not a valid name") + public void testExtractBasicAuthPrincipalsBlankName() { + Map<String, Object> config = new HashMap<>(); + config.put("principals", "admin, ,user"); + config.put("principals.admin.password", "secret"); + config.put("principals.user.password", "secret"); + PinotConfiguration configuration = new PinotConfiguration(config); + BasicAuthUtils.extractBasicAuthPrincipals(configuration, "principals"); } } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/auth/ZkBasicAuthTest.java b/pinot-core/src/test/java/org/apache/pinot/core/auth/ZkBasicAuthTest.java index 8ddecec764b..1de512e72ae 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/auth/ZkBasicAuthTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/auth/ZkBasicAuthTest.java @@ -18,8 +18,11 @@ */ package org.apache.pinot.core.auth; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import java.util.Collections; +import java.util.List; +import java.util.Map; import org.apache.pinot.spi.config.user.ComponentType; import org.apache.pinot.spi.config.user.RoleType; import org.testng.Assert; @@ -28,41 +31,46 @@ import org.testng.annotations.Test; public class ZkBasicAuthTest { - @Test - public void testBasicAuthPrincipal() { - Assert.assertTrue(new ZkBasicAuthPrincipal("name", "token", "password", - ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable"), - Collections.emptySet(), ImmutableSet.of("READ")).hasTable("myTable")); - Assert.assertTrue(new ZkBasicAuthPrincipal("name", "token", "password", - ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable", "myTable1"), - Collections.emptySet(), ImmutableSet.of("Read")).hasTable("myTable1")); - Assert.assertFalse(new ZkBasicAuthPrincipal("name", "token", "password", - ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable"), - Collections.emptySet(), ImmutableSet.of("read")).hasTable("myTable1")); - Assert.assertFalse(new ZkBasicAuthPrincipal("name", "token", "password", - ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable", "myTable1"), - Collections.emptySet(), ImmutableSet.of("read")).hasTable("myTable2")); - Assert.assertFalse(new ZkBasicAuthPrincipal("name", "token", "password", - ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable", "myTable1"), - ImmutableSet.of("myTable3"), ImmutableSet.of("Read")).hasTable("myTable3")); - Assert.assertTrue(new ZkBasicAuthPrincipal("name", "token", "password", - ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable", "myTable1"), - ImmutableSet.of("myTable"), ImmutableSet.of("read")).hasTable("myTable1")); - Assert.assertFalse(new ZkBasicAuthPrincipal("name", "token", "password", - ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), Collections.emptySet(), - ImmutableSet.of("myTable"), ImmutableSet.of("read")).hasTable("myTable")); + @Test + public void testBasicAuthPrincipal() { + Assert.assertTrue(new ZkBasicAuthPrincipal("name", "token", "password", + ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable"), + Collections.emptySet(), ImmutableSet.of("READ")).hasTable("myTable")); + Assert.assertTrue(new ZkBasicAuthPrincipal("name", "token", "password", + ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable", "myTable1"), + Collections.emptySet(), ImmutableSet.of("Read")).hasTable("myTable1")); + Assert.assertFalse(new ZkBasicAuthPrincipal("name", "token", "password", + ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable"), + Collections.emptySet(), ImmutableSet.of("read")).hasTable("myTable1")); + Assert.assertFalse(new ZkBasicAuthPrincipal("name", "token", "password", + ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable", "myTable1"), + Collections.emptySet(), ImmutableSet.of("read")).hasTable("myTable2")); + Assert.assertFalse(new ZkBasicAuthPrincipal("name", "token", "password", + ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable", "myTable1"), + ImmutableSet.of("myTable3"), ImmutableSet.of("Read")).hasTable("myTable3")); + Assert.assertTrue(new ZkBasicAuthPrincipal("name", "token", "password", + ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable", "myTable1"), + ImmutableSet.of("myTable"), ImmutableSet.of("read")).hasTable("myTable1")); + Assert.assertFalse(new ZkBasicAuthPrincipal("name", "token", "password", + ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), Collections.emptySet(), + ImmutableSet.of("myTable"), ImmutableSet.of("read")).hasTable("myTable")); - Assert.assertTrue(new ZkBasicAuthPrincipal("name", "token", "password", - ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable"), - Collections.emptySet(), ImmutableSet.of("READ")).hasPermission("read")); - Assert.assertTrue(new ZkBasicAuthPrincipal("name", "token", "password", - ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable"), - Collections.emptySet(), ImmutableSet.of("Read")).hasPermission("READ")); - Assert.assertTrue(new ZkBasicAuthPrincipal("name", "token", "password", - ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable"), - Collections.emptySet(), ImmutableSet.of("read")).hasPermission("Read")); - Assert.assertFalse(new ZkBasicAuthPrincipal("name", "token", "password", - ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable"), - Collections.emptySet(), ImmutableSet.of("read")).hasPermission("write")); - } + Assert.assertTrue(new ZkBasicAuthPrincipal("name", "token", "password", + ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable"), + Collections.emptySet(), ImmutableSet.of("READ")).hasPermission("read")); + Assert.assertTrue(new ZkBasicAuthPrincipal("name", "token", "password", + ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable"), + Collections.emptySet(), ImmutableSet.of("Read")).hasPermission("READ")); + Assert.assertTrue(new ZkBasicAuthPrincipal("name", "token", "password", + ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable"), + Collections.emptySet(), ImmutableSet.of("read")).hasPermission("Read")); + Assert.assertFalse(new ZkBasicAuthPrincipal("name", "token", "password", + ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable"), + Collections.emptySet(), ImmutableSet.of("read")).hasPermission("write")); + + Assert.assertEquals(new ZkBasicAuthPrincipal("name", "token", "password", + ComponentType.CONTROLLER.name(), RoleType.ADMIN.name(), ImmutableSet.of("myTable"), + Collections.emptySet(), ImmutableSet.of("read"), Map.of("myTable", List.of("cityID > 100"))).getRLSFilters( + "myTable").get(), ImmutableList.of("cityID > 100")); + } } diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/RowLevelSecurityIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/RowLevelSecurityIntegrationTest.java new file mode 100644 index 00000000000..de24564312a --- /dev/null +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/RowLevelSecurityIntegrationTest.java @@ -0,0 +1,330 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.integration.tests; + +import com.fasterxml.jackson.databind.JsonNode; +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import org.apache.commons.io.FileUtils; +import org.apache.pinot.client.Connection; +import org.apache.pinot.client.ConnectionFactory; +import org.apache.pinot.client.JsonAsyncHttpPinotClientTransportFactory; +import org.apache.pinot.controller.helix.ControllerRequestClient; +import org.apache.pinot.spi.config.table.TableConfig; +import org.apache.pinot.spi.data.Schema; +import org.apache.pinot.spi.env.PinotConfiguration; +import org.apache.pinot.util.TestUtils; +import org.junit.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static org.apache.pinot.integration.tests.BasicAuthTestUtils.AUTH_HEADER; +import static org.apache.pinot.integration.tests.BasicAuthTestUtils.AUTH_HEADER_USER; +import static org.apache.pinot.integration.tests.BasicAuthTestUtils.AUTH_TOKEN; +import static org.apache.pinot.integration.tests.ClusterIntegrationTestUtils.getBrokerQueryApiUrl; + + +public class RowLevelSecurityIntegrationTest extends BaseClusterIntegrationTest { + private static final String AUTH_TOKEN_USER_2 = "Basic dXNlcjI6bm90U29TZWNyZXQ"; + public static final Map<String, String> AUTH_HEADER_USER_2 = Map.of("Authorization", AUTH_TOKEN_USER_2); + private static final String DEFAULT_TABLE_NAME_2 = "mytable2"; + private static final String DEFAULT_TABLE_NAME_3 = "mytable3"; + + protected List<File> _avroFiles; + private static final Logger LOGGER = LoggerFactory.getLogger(RowLevelSecurityIntegrationTest.class); + + @Override + protected void overrideControllerConf(Map<String, Object> properties) { + properties.put("controller.segment.fetcher.auth.token", AUTH_TOKEN); + properties.put("controller.admin.access.control.factory.class", + "org.apache.pinot.controller.api.access.BasicAuthAccessControlFactory"); + properties.put("controller.admin.access.control.principals", "admin, user, user2"); + properties.put("controller.admin.access.control.principals.admin.password", "verysecret"); + properties.put("controller.admin.access.control.principals.user.password", "secret"); + properties.put("controller.admin.access.control.principals.user2.password", "notSoSecret"); + properties.put("controller.admin.access.control.principals.user.tables", "mytable, mytable2, mytable3"); + properties.put("controller.admin.access.control.principals.user.permissions", "read"); + properties.put("controller.admin.access.control.principals.user2.tables", "mytable, mytable2, mytable3"); + properties.put("controller.admin.access.control.principals.user2.permissions", "read"); + } + + @Override + protected void overrideBrokerConf(PinotConfiguration brokerConf) { + brokerConf.setProperty("pinot.broker.access.control.class", + "org.apache.pinot.broker.broker.BasicAuthAccessControlFactory"); + brokerConf.setProperty("pinot.broker.access.control.principals", "admin, user, user2"); + brokerConf.setProperty("pinot.broker.access.control.principals.admin.password", "verysecret"); + brokerConf.setProperty("pinot.broker.access.control.principals.user.password", "secret"); + brokerConf.setProperty("pinot.broker.access.control.principals.user2.password", "notSoSecret"); + brokerConf.setProperty("pinot.broker.access.control.principals.user.tables", "mytable, mytable2, mytable3"); + brokerConf.setProperty("pinot.broker.access.control.principals.user.permissions", "read"); + brokerConf.setProperty("pinot.broker.access.control.principals.user.mytable.rls", "AirlineID='19805'"); + brokerConf.setProperty("pinot.broker.access.control.principals.user.mytable3.rls", + "AirlineID='20409' OR AirTime>'300', DestStateName='Florida'"); + brokerConf.setProperty("pinot.broker.access.control.principals.user2.tables", "mytable, mytable2, mytable3"); + brokerConf.setProperty("pinot.broker.access.control.principals.user2.permissions", "read"); + brokerConf.setProperty("pinot.broker.access.control.principals.user2.mytable.rls", + "AirlineID='19805', DestStateName='California'"); + brokerConf.setProperty("pinot.broker.access.control.principals.user2.mytable2.rls", + "AirlineID='20409', DestStateName='Florida'"); + brokerConf.setProperty("pinot.broker.access.control.principals.user2.mytable3.rls", + "AirlineID='20409' OR DestStateName='California', DestStateName='Florida'"); + } + + @Override + protected void overrideServerConf(PinotConfiguration serverConf) { + serverConf.setProperty("pinot.server.segment.fetcher.auth.token", AUTH_TOKEN); + serverConf.setProperty("pinot.server.segment.uploader.auth.token", AUTH_TOKEN); + serverConf.setProperty("pinot.server.instance.auth.token", AUTH_TOKEN); + } + + @Override + public ControllerRequestClient getControllerRequestClient() { + if (_controllerRequestClient == null) { + _controllerRequestClient = + new ControllerRequestClient(_controllerRequestURLBuilder, getHttpClient(), AUTH_HEADER); + } + return _controllerRequestClient; + } + + @Override + protected Connection getPinotConnection() { + if (_pinotConnection == null) { + JsonAsyncHttpPinotClientTransportFactory factory = new JsonAsyncHttpPinotClientTransportFactory(); + factory.setHeaders(AUTH_HEADER); + + _pinotConnection = + ConnectionFactory.fromZookeeper(getZkUrl() + "/" + getHelixClusterName(), factory.buildTransport()); + } + return _pinotConnection; + } + + @BeforeClass + public void setUp() + throws Exception { + TestUtils.ensureDirectoriesExistAndEmpty(_tempDir, _segmentDir, _tarDir); + startZk(); + startController(); + startBroker(); + startServer(); + + startKafka(); + _avroFiles = unpackAvroData(_tempDir); + pushAvroIntoKafka(_avroFiles); + + // Set up a table for testing different principals. + setupTable(DEFAULT_TABLE_NAME); + setupTable(DEFAULT_TABLE_NAME_2); + setupTable(DEFAULT_TABLE_NAME_3); + + waitForAllDocsLoaded(600_000L); + } + + private void setupTable(String tableName) + throws Exception { + + Schema schema = createSchema(); + schema.setSchemaName(tableName); + addSchema(schema); + + TableConfig tableConfig = createRealtimeTableConfig(_avroFiles.get(0)); + tableConfig.setTableName(tableName); + tableConfig.getValidationConfig().setRetentionTimeUnit("DAYS"); + tableConfig.getValidationConfig().setRetentionTimeValue("100000"); + addTableConfig(tableConfig); + + waitForDocsLoaded(600_000L, true, tableConfig.getTableName()); + } + + @AfterClass + public void tearDown() + throws IOException { + LOGGER.info("Tearing down..."); + dropRealtimeTable(getTableName()); + stopServer(); + stopBroker(); + stopController(); + stopKafka(); + stopZk(); + FileUtils.deleteDirectory(_tempDir); + } + + @Test + public void testRowFiltersForSingleStageQuery() + throws Exception { + setUseMultiStageQueryEngine(false); + String query = String.format("select count(*) from %s", DEFAULT_TABLE_NAME); + String queryWithFiltersForUser1 = "select count(*)from mytable where AirlineID=19805"; + String queryWithFiltersForUser2 = + "select count(*)from mytable where AirlineID=19805 and DestStateName='California'"; + + // compare admin response with that of user + Assert.assertTrue( + compareRows(queryBroker(queryWithFiltersForUser1, AUTH_HEADER), queryBroker(query, AUTH_HEADER_USER))); + Assert.assertTrue( + compareRows(queryBroker(queryWithFiltersForUser2, AUTH_HEADER), queryBroker(query, AUTH_HEADER_USER_2))); + } + + @Test + public void testRowFiltersForSingleTableWithMultiStageQuery() + throws Exception { + setUseMultiStageQueryEngine(true); + String query = "select count(*), avg(ActualElapsedTime) from mytable WHERE ActualElapsedTime > " + + "(select avg(ActualElapsedTime) as avg_profit from mytable)"; + String queryWithFiltersForUser1 = "select count(*), avg(ActualElapsedTime) " + + "from mytable " + + "WHERE ActualElapsedTime > (" + + " select avg(ActualElapsedTime) as avg_profit " + + " from mytable " + + " where AirlineID = '19805' " + + " ) " + + " and AirlineID = '19805'"; + + String queryWithFiltersForUser2 = "select count(*), avg(ActualElapsedTime) " + + "from mytable " + + "WHERE ActualElapsedTime > (" + + " select avg(ActualElapsedTime) as avg_profit " + + " from mytable " + + " where AirlineID = '19805' " + + " and DestStateName = 'California'" + + " ) " + + " and AirlineID = '19805'" + + " and DestStateName = 'California'"; + + // compare admin response with that of user + Assert.assertTrue( + compareRows(queryBroker(queryWithFiltersForUser1, AUTH_HEADER), queryBroker(query, AUTH_HEADER_USER))); + Assert.assertTrue( + compareRows(queryBroker(queryWithFiltersForUser2, AUTH_HEADER), queryBroker(query, AUTH_HEADER_USER_2))); + } + + @Test + public void testRowFiltersForTwoTablesWithMultiStageQuery() + throws Exception { + setUseMultiStageQueryEngine(true); + String query = "select count(*), avg(ActualElapsedTime) from mytable WHERE ActualElapsedTime > 0.1 * ABS(" + + "(select avg(ActualElapsedTime) as avg_profit from mytable2))"; + String queryWithFiltersForUser1 = "select count(*), avg(ActualElapsedTime) " + + "from mytable " + + "WHERE ActualElapsedTime > 0.1 * ABS((" + + " select avg(ActualElapsedTime) as avg_profit " + + " from mytable2 " + + " )) " + + " and AirlineID = '19805'"; + String queryWithFiltersForUser2 = "SELECT COUNT(*), AVG(ActualElapsedTime)" + + " FROM mytable " + + " WHERE ActualElapsedTime > 0.1 * ABS((" + + " SELECT AVG(ActualElapsedTime) AS avg_profit" + + " FROM mytable2" + + " WHERE AirlineID = '20409'" + + " AND DestStateName = 'Florida'" + + " ))" + + " AND DestStateName = 'California'" + + " AND AirlineID='19805'"; + + // compare admin response with that of user + Assert.assertTrue( + compareRows(queryBroker(queryWithFiltersForUser1, AUTH_HEADER), queryBroker(query, AUTH_HEADER_USER))); + Assert.assertTrue( + compareRows(queryBroker(queryWithFiltersForUser2, AUTH_HEADER), queryBroker(query, AUTH_HEADER_USER_2))); + } + + @Test + public void testRowFiltersForTwoTablesWithComplexExpressions() + throws Exception { + + // Test for single-stage + setUseMultiStageQueryEngine(false); + + String singleStageQuery = + String.format("select AVG(CRSDepTime) as avg_dep_time, count(*) from %s", DEFAULT_TABLE_NAME_3); + + String queryWithFiltersForUser1 = + "select AVG(CRSDepTime) as avg_dep_time, count(*) from mytable3 where (AirlineID='20409' OR AirTime>'300') " + + "AND " + + "(DestStateName='Florida')"; + + Assert.assertTrue( + compareRows(queryBroker(queryWithFiltersForUser1, AUTH_HEADER), + queryBroker(singleStageQuery, AUTH_HEADER_USER))); + + // Test for multi-stage + setUseMultiStageQueryEngine(true); + String multiStageQuery = "select count(*), avg(ActualElapsedTime) from mytable WHERE ActualElapsedTime > 0.1 * ABS(" + + "(select avg(ActualElapsedTime) as avg_profit from mytable3))"; + String queryWithFiltersForUser2 = "SELECT COUNT(*), AVG(ActualElapsedTime)" + + " FROM mytable " + + " WHERE ActualElapsedTime > 0.1 * ABS((" + + " SELECT AVG(ActualElapsedTime) AS avg_profit" + + " FROM mytable3" + + " WHERE (AirlineID = '20409'" + + " OR DestStateName = 'California')" + + " AND DestStateName = 'Florida'" + + " ))" + + " AND DestStateName = 'California'" + + " AND AirlineID='19805'"; + + // compare admin response with that of user + + Assert.assertTrue( + compareRows(queryBroker(queryWithFiltersForUser2, AUTH_HEADER), + queryBroker(multiStageQuery, AUTH_HEADER_USER_2))); + } + + private JsonNode queryBroker(String query, Map<String, String> headers) + throws Exception { + JsonNode response = + postQuery(query, getBrokerQueryApiUrl(getBrokerBaseApiUrl(), useMultiStageQueryEngine()), headers, + getExtraQueryProperties()); + return response; + } + + private boolean compareRows(JsonNode expectedResponse, JsonNode response) { + JsonNode responseRow = response.get("resultTable").get("rows").get(0); + JsonNode expectedRow = expectedResponse.get("resultTable").get("rows").get(0); + + // Compare each column + for (int i = 0; i < responseRow.size(); i++) { + JsonNode responseValue = responseRow.get(i); + JsonNode expectedValue = expectedRow.get(i); + + if (responseValue.isNumber() && expectedValue.isNumber()) { + // For numeric values, use appropriate comparison + if (responseValue.isIntegralNumber() && expectedValue.isIntegralNumber()) { + // Integer comparison + if (responseValue.asLong() != expectedValue.asLong()) { + return false; + } + } else { + // Floating point comparison with delta + double delta = Math.abs(responseValue.asDouble() - expectedValue.asDouble()); + if (delta > 0.001) { + return false; + } + } + } + } + return true; + } +} diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java index 6899dd08069..f8adf7fb159 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java @@ -742,6 +742,10 @@ public class QueryEnvironment { return _sqlNodeAndOptions.getSqlNode().getKind().equals(SqlKind.EXPLAIN); } + public PlannerContext getPlannerContext() { + return _plannerContext; + } + /// Explain the query plan. /// The original query must be an EXPLAIN query and way it will be explained depends on the options of the EXPLAIN /// query and the [QueryEnvironment.Config] used to create the [QueryEnvironment] that compiled this query. diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java index 85e0a29ad54..1706187d4a7 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java @@ -86,6 +86,7 @@ import org.apache.pinot.spi.utils.CommonConstants.MultiStageQueryRunner.JoinOver import org.apache.pinot.spi.utils.CommonConstants.MultiStageQueryRunner.WindowOverFlowMode; import org.apache.pinot.spi.utils.CommonConstants.Query.Request.MetadataKeys; import org.apache.pinot.spi.utils.CommonConstants.Server; +import org.apache.pinot.sql.parsers.rewriter.RlsUtils; import org.apache.pinot.tsdb.planner.TimeSeriesPlanConstants.WorkerRequestMetadataKeys; import org.apache.pinot.tsdb.planner.TimeSeriesPlanConstants.WorkerResponseMetadataKeys; import org.apache.pinot.tsdb.spi.PinotTimeSeriesConfiguration; @@ -285,8 +286,10 @@ public class QueryRunner { workerMetadata, pipelineBreakerResult, parentContext, _sendStats.getAsBoolean()); OpChain opChain; if (workerMetadata.isLeafStageWorker()) { + Map<String, String> rlsFilters = RlsUtils.extractRlsFilters(requestMetadata); opChain = - ServerPlanRequestUtils.compileLeafStage(executionContext, stagePlan, _leafQueryExecutor, _executorService); + ServerPlanRequestUtils.compileLeafStage(executionContext, stagePlan, _leafQueryExecutor, _executorService, + rlsFilters); } else { opChain = PlanNodeToOpChain.convert(stagePlan.getRootNode(), executionContext); } @@ -528,7 +531,7 @@ public class QueryRunner { OpChain opChain = ServerPlanRequestUtils.compileLeafStage(executionContext, stagePlan, _leafQueryExecutor, _executorService, - leafNodesConsumer, true); + leafNodesConsumer, true, Map.of()); opChain.close(); // probably unnecessary, but formally needed PlanNode rootNode = substituteNode(stagePlan.getRootNode(), leafNodes); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestUtils.java index 807a9a9fd63..6e5d1aac01f 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestUtils.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestUtils.java @@ -29,6 +29,7 @@ import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.function.BiConsumer; import javax.annotation.Nullable; +import org.apache.commons.collections.MapUtils; import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.metrics.ServerMetrics; import org.apache.pinot.common.request.BrokerRequest; @@ -67,6 +68,7 @@ import org.apache.pinot.sql.parsers.rewriter.NonAggregationGroupByToDistinctQuer import org.apache.pinot.sql.parsers.rewriter.PredicateComparisonRewriter; import org.apache.pinot.sql.parsers.rewriter.QueryRewriter; import org.apache.pinot.sql.parsers.rewriter.QueryRewriterFactory; +import org.apache.pinot.sql.parsers.rewriter.RlsFiltersRewriter; public class ServerPlanRequestUtils { @@ -76,11 +78,18 @@ public class ServerPlanRequestUtils { private static final int DEFAULT_LEAF_NODE_LIMIT = Integer.MAX_VALUE; private static final List<String> QUERY_REWRITERS_CLASS_NAMES = ImmutableList.of(PredicateComparisonRewriter.class.getName(), - NonAggregationGroupByToDistinctQueryRewriter.class.getName()); + NonAggregationGroupByToDistinctQueryRewriter.class.getName(), RlsFiltersRewriter.class.getName()); private static final List<QueryRewriter> QUERY_REWRITERS = new ArrayList<>(QueryRewriterFactory.getQueryRewriters(QUERY_REWRITERS_CLASS_NAMES)); private static final QueryOptimizer QUERY_OPTIMIZER = new QueryOptimizer(); + public static OpChain compileLeafStage(OpChainExecutionContext executionContext, StagePlan stagePlan, + QueryExecutor leafQueryExecutor, ExecutorService executorService, Map<String, String> rowFilters) { + return compileLeafStage(executionContext, stagePlan, leafQueryExecutor, executorService, + (planNode, multiStageOperator) -> { + }, false, rowFilters); + } + public static OpChain compileLeafStage( OpChainExecutionContext executionContext, StagePlan stagePlan, @@ -88,7 +97,7 @@ public class ServerPlanRequestUtils { ExecutorService executorService) { return compileLeafStage(executionContext, stagePlan, leafQueryExecutor, executorService, (planNode, multiStageOperator) -> { - }, false); + }, false, null); } /** @@ -104,7 +113,7 @@ public class ServerPlanRequestUtils { QueryExecutor leafQueryExecutor, ExecutorService executorService, BiConsumer<PlanNode, MultiStageOperator> relationConsumer, - boolean explain) { + boolean explain, @Nullable Map<String, String> rowFilters) { long queryArrivalTimeMs = System.currentTimeMillis(); ServerPlanRequestContext serverContext = new ServerPlanRequestContext(stagePlan, leafQueryExecutor, executorService, @@ -114,6 +123,11 @@ public class ServerPlanRequestUtils { // 2. Convert PinotQuery into InstanceRequest list (one for each physical table) PinotQuery pinotQuery = serverContext.getPinotQuery(); pinotQuery.setExplain(explain); + + if (MapUtils.isNotEmpty(rowFilters)) { + pinotQuery.setQueryOptions(rowFilters); + } + List<InstanceRequest> instanceRequests; if (executionContext.getWorkerMetadata().getLogicalTableSegmentsMap() != null) { instanceRequests = constructLogicalTableServerQueryRequests(executionContext, pinotQuery, diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/auth/TableRowColAccessResult.java b/pinot-spi/src/main/java/org/apache/pinot/spi/auth/TableRowColAccessResult.java new file mode 100644 index 00000000000..4a8fa88657c --- /dev/null +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/auth/TableRowColAccessResult.java @@ -0,0 +1,36 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.spi.auth; + +import java.util.List; +import java.util.Optional; + + +/** + * This interfaces carries the RLS/CLS filter for a particular table + */ +public interface TableRowColAccessResult { + /** + * Returns the RLS filters associated with a particular table. RLS filters are defined as a list. + * @return optional of the RLS filters. Empty optional if there are no RLS filters defined on this table + */ + Optional<List<String>> getRLSFilters(); + + void setRLSFilters(List<String> rlsFilters); +} diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/auth/TableRowColAccessResultImpl.java b/pinot-spi/src/main/java/org/apache/pinot/spi/auth/TableRowColAccessResultImpl.java new file mode 100644 index 00000000000..5637b90c5e3 --- /dev/null +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/auth/TableRowColAccessResultImpl.java @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.spi.auth; + +import java.util.List; +import java.util.Optional; + + +public class TableRowColAccessResultImpl implements TableRowColAccessResult { + + private static final TableRowColAccessResult UNRESTRICTED = new TableRowColAccessResultImpl(); + + private List<String> _rlsFilters; + + public TableRowColAccessResultImpl() { + } + + public TableRowColAccessResultImpl(List<String> rlsFilters) { + _rlsFilters = rlsFilters; + } + + @Override + public void setRLSFilters(List<String> rlsFilters) { + _rlsFilters = rlsFilters; + } + + @Override + public Optional<List<String>> getRLSFilters() { + return Optional.ofNullable(_rlsFilters); + } + + public static TableRowColAccessResult unrestricted() { + return UNRESTRICTED; + } +} diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java index d493deb9cb2..26305c5698b 100644 --- a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java @@ -74,6 +74,8 @@ public class CommonConstants { public static final String JFR = "pinot.jfr"; + public static final String RLS_FILTERS = "rlsFilters"; + /** * The state of the consumer for a given segment */ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org