This is an automated email from the ASF dual-hosted git repository. kturner pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/accumulo-access.git
The following commit(s) were added to refs/heads/main by this push: new cb6a2d4 Improves performance by avoiding parse tree creation (#31) cb6a2d4 is described below commit cb6a2d4171ae7c1068b0cf89343edaa46726fbdf Author: Keith Turner <ktur...@apache.org> AuthorDate: Wed Jan 24 14:40:00 2024 -0500 Improves performance by avoiding parse tree creation (#31) Improves the performance by evaluating and parsing an expression at the same time and avoiding creation of a parse tree. Also some other steps were taken to avoid other object allocations. Seeing really good performance with these changes when running the benchmark. --- .../apache/accumulo/access/AccessEvaluator.java | 3 +- .../accumulo/access/AccessEvaluatorImpl.java | 38 +++--- .../apache/accumulo/access/AccessExpression.java | 129 +++++++++++++++----- .../accumulo/access/AccessExpressionImpl.java | 130 ++++++++++---------- .../java/org/apache/accumulo/access/AeNode.java | 134 +++------------------ .../org/apache/accumulo/access/Authorizations.java | 7 +- .../org/apache/accumulo/access/BytesWrapper.java | 42 ++++--- .../access/{Parser.java => Normalizer.java} | 65 +++++----- .../apache/accumulo/access/ParserEvaluator.java | 130 ++++++++++++++++++++ .../java/org/apache/accumulo/access/Tokenizer.java | 12 +- .../accumulo/access/AccessEvaluatorTest.java | 55 +++++++-- .../accumulo/access/AccessExpressionBenchmark.java | 27 +---- .../accumulo/access/AccessExpressionTest.java | 55 +++++++-- 13 files changed, 495 insertions(+), 332 deletions(-) diff --git a/src/main/java/org/apache/accumulo/access/AccessEvaluator.java b/src/main/java/org/apache/accumulo/access/AccessEvaluator.java index cc8cf68..5fe7619 100644 --- a/src/main/java/org/apache/accumulo/access/AccessEvaluator.java +++ b/src/main/java/org/apache/accumulo/access/AccessEvaluator.java @@ -67,8 +67,7 @@ public interface AccessEvaluator { boolean canAccess(byte[] accessExpression) throws IllegalAccessExpressionException; /** - * @param accessExpression a validated and parsed access expression. The implementation of this - * method may be able to reuse the internal parse tree and avoid re-parsing. + * @param accessExpression previously validated access expression * @return true if the expression is visible using the authorizations supplied at creation, false * otherwise */ diff --git a/src/main/java/org/apache/accumulo/access/AccessEvaluatorImpl.java b/src/main/java/org/apache/accumulo/access/AccessEvaluatorImpl.java index 360ff55..5619701 100644 --- a/src/main/java/org/apache/accumulo/access/AccessEvaluatorImpl.java +++ b/src/main/java/org/apache/accumulo/access/AccessEvaluatorImpl.java @@ -37,6 +37,13 @@ import java.util.stream.Stream; class AccessEvaluatorImpl implements AccessEvaluator { private final Collection<Predicate<BytesWrapper>> authorizedPredicates; + private static final byte[] EMPTY = new byte[0]; + + private final ThreadLocal<BytesWrapper> lookupWrappers = + ThreadLocal.withInitial(() -> new BytesWrapper(EMPTY)); + private final ThreadLocal<Tokenizer> tokenizers = + ThreadLocal.withInitial(() -> new Tokenizer(EMPTY)); + private AccessEvaluatorImpl(Authorizer authorizationChecker) { this.authorizedPredicates = List.of(auth -> authorizationChecker.isAuthorized(unescape(auth))); } @@ -126,30 +133,31 @@ class AccessEvaluatorImpl implements AccessEvaluator { } @Override - public boolean canAccess(String expression) throws IllegalAccessExpressionException { - return evaluate(new AccessExpressionImpl(expression)); + public boolean canAccess(AccessExpression expression) { + return canAccess(expression.getExpression()); } @Override - public boolean canAccess(byte[] expression) throws IllegalAccessExpressionException { - return evaluate(new AccessExpressionImpl(expression)); + public boolean canAccess(String expression) throws IllegalAccessExpressionException { + return evaluate(expression.getBytes(UTF_8)); } @Override - public boolean canAccess(AccessExpression expression) throws IllegalAccessExpressionException { - if (expression instanceof AccessExpressionImpl) { - return evaluate((AccessExpressionImpl) expression); - } else { - return canAccess(expression.getExpression()); - } + public boolean canAccess(byte[] expression) throws IllegalAccessExpressionException { + return evaluate(expression); } - public boolean evaluate(AccessExpressionImpl accessExpression) - throws IllegalAccessExpressionException { - // The AccessEvaluator computes a trie from the given Authorizations, that AccessExpressions can - // be evaluated against. + boolean evaluate(byte[] accessExpression) throws IllegalAccessExpressionException { + var bytesWrapper = lookupWrappers.get(); + for (var auths : authorizedPredicates) { - if (!accessExpression.aeNode.canAccess(auths)) { + var tokenizer = tokenizers.get(); + tokenizer.reset(accessExpression); + Predicate<Tokenizer.AuthorizationToken> atp = authToken -> { + bytesWrapper.set(authToken.data, authToken.start, authToken.len); + return auths.test(bytesWrapper); + }; + if (!ParserEvaluator.parseAccessExpression(tokenizer, atp, authToken -> true)) { return false; } } diff --git a/src/main/java/org/apache/accumulo/access/AccessExpression.java b/src/main/java/org/apache/accumulo/access/AccessExpression.java index c940e7c..c2fc000 100644 --- a/src/main/java/org/apache/accumulo/access/AccessExpression.java +++ b/src/main/java/org/apache/accumulo/access/AccessExpression.java @@ -18,24 +18,30 @@ */ package org.apache.accumulo.access; -import java.util.Arrays; - /** - * An opaque type that contains a parsed access expression. When this type is constructed with - * {@link #of(String)} and then used with {@link AccessEvaluator#canAccess(AccessExpression)} it can - * be more efficient and avoid re-parsing the expression. + * This class offers the ability to validate, build, and normalize access expressions. An instance + * of this class should wrap an immutable, validated access expression. If passing access + * expressions as arguments in code, consider using this type instead of a String. The advantage of + * passing this type over a String is that its known to be a valid expression. * - * Below is an example of using this API. + * <p> + * >Below is an example of using this API. * * <pre> * {@code + * // The following authorization does not need quoting, so the return value is the same as the + * // input. * var auth1 = AccessExpression.quote("CAT"); + * // The following two authorizations need quoting and the return values will be quoted. * var auth2 = AccessExpression.quote("🦕"); * var auth3 = AccessExpression.quote("🦖"); - * var visExp = AccessExpression - * .of("(" + auth1 + "&" + auth3 + ")|(" + auth1 + "&" + auth2 + "&" + auth1 + ")"); + * var exp = "(" + auth1 + "&" + auth3 + ")|(" + auth1 + "&" + auth2 + "&" + auth1 + ")"; + * // Validate the expression, but do not normalize it + * var visExp = AccessExpression.of(exp); * System.out.println(visExp.getExpression()); - * System.out.println(visExp.normalize()); + * // Validate and normalize the expression. + * System.out.println(AccessExpression.of(exp, true).getExpression()); + * // Print the unique authorization in the expression * System.out.println(visExp.getAuthorizations()); * } * </pre> @@ -48,6 +54,15 @@ import java.util.Arrays; * [🦖, CAT, 🦕] * </pre> * + * The following code will throw an {@link IllegalAccessExpressionException} because the expression + * is not valid. + * + * <pre> + * {@code + * AccessExpression.validate("A&B|C"); + * } + * </pre> + * * @see <a href="https://github.com/apache/accumulo-access">Accumulo Access Documentation</a> * @since 1.0.0 */ @@ -59,47 +74,103 @@ public interface AccessExpression { String getExpression(); /** - * Deduplicate, sort, and flatten expressions. - * + * @return the unique set of authorizations that occur in the expression. For example, for the + * expression {@code (A&B)|(A&C)|(A&D)}, this method would return {@code [A,B,C,D]}. + */ + Authorizations getAuthorizations(); + + /** + * This is equivalent to calling {@code AccessExpression.of(expression, false);} + */ + static AccessExpression of(String expression) throws IllegalAccessExpressionException { + return new AccessExpressionImpl(expression, false); + } + + /** * <p> - * As an example of flattening, the expression {@code A&(B&C)} can be flattened to {@code A&B&C}. + * Validates an access expression and creates an immutable AccessExpression object. * * <p> - * As an example of sorting, the expression {@code (Z&Y)|(C&B)} can be sorted to - * {@code (B&C)|(Y&Z)} + * When the {@code normalize} parameter is true, then will deduplicate, sort, flatten, and remove + * unneeded parens or quotes in the expressions. Normalization is done in addition to validation. + * The following list gives examples of what each normalization step does. * - * <p> - * As an example of deduplication, the expression {@code X&Y&X} is equivalent to {@code X&Y} + * <ul> + * <li>As an example of flattening, the expression {@code A&(B&C)} flattens to {@code A&B&C}.</li> + * <li>As an example of sorting, the expression {@code (Z&Y)|(C&B)} sorts to + * {@code (B&C)|(Y&Z)}</li> + * <li>As an example of deduplication, the expression {@code X&Y&X} normalizes to {@code X&Y}</li> + * <li>As an example of unneed quotes, the expression {@code "ABC"&"XYZ"} normalizes to + * {@code ABC&XYZ}</li> + * <li>As an example of unneed parens, the expression {@code (((ABC)|(XYZ)))} normalizes to * + * {@code ABC|XYZ}</li> + * </ul> * - * @return A normalized version of the access expression that removes duplicates and orders the - * expression in a consistent way. + * @param expression an access expression + * @param normalize If true then the expression will be normalized, if false the expression will + * only be validated. Normalization is expensive so only use when needed. If repeatedly + * normalizing expressions, consider using a cache that maps un-normalized expressions to + * normalized ones. Since the normalization process is deterministic, the computation can + * be cached. + * @throws IllegalAccessExpressionException when the expression is not valid. */ - String normalize(); + static AccessExpression of(String expression, boolean normalize) + throws IllegalAccessExpressionException { + return new AccessExpressionImpl(expression, normalize); + } /** - * @return the unique set of authorizations that occur in the expression. For example, for the - * expression {@code (A&B)|(A&C)|(A&D)}, this method would return {@code [A,B,C,D]}. + * <p> + * This is equivalent to calling {@code AccessExpression.of(expression, false);} */ - Authorizations getAuthorizations(); - - static AccessExpression of(String expression) throws IllegalAccessExpressionException { - return new AccessExpressionImpl(expression); + static AccessExpression of(byte[] expression) throws IllegalAccessExpressionException { + return new AccessExpressionImpl(expression, false); } /** - * @param expression is expected to be encoded using UTF-8 + * <p> + * Validates an access expression and creates an immutable AccessExpression object. + * + * <p> + * If only validation is needed, then call {@link #validate(byte[])} because it will avoid copying + * the expression like this method does. This method must copy the byte array into a String + * inorder to create an immutable AccessExpression. + * + * @see #of(String, boolean) for information about normlization. + * @param expression an access expression that is expected to be encoded using UTF-8 + * @param normalize If true then the expression will be normalized, if false the expression will + * only be validated. Normalization is expensive so only use when needed. + * @throws IllegalAccessExpressionException when the expression is not valid. */ - static AccessExpression of(byte[] expression) throws IllegalAccessExpressionException { - return new AccessExpressionImpl(Arrays.copyOf(expression, expression.length)); + static AccessExpression of(byte[] expression, boolean normalize) + throws IllegalAccessExpressionException { + return new AccessExpressionImpl(expression, normalize); } /** - * @return an empty AccessExpression. + * @return an empty AccessExpression that is immutable. */ static AccessExpression of() { return AccessExpressionImpl.EMPTY; } + /** + * Quickly validates that an access expression is properly formed. + * + * @param expression a potential access expression that is expected to be encoded using UTF-8 + * @throws IllegalAccessExpressionException if the given expression is not valid + */ + static void validate(byte[] expression) throws IllegalAccessExpressionException { + AccessExpressionImpl.validate(expression); + } + + /** + * @see #validate(byte[]) + */ + static void validate(String expression) throws IllegalAccessExpressionException { + AccessExpressionImpl.validate(expression); + } + /** * Authorizations occurring in an access expression can only contain the characters listed in the * <a href= diff --git a/src/main/java/org/apache/accumulo/access/AccessExpressionImpl.java b/src/main/java/org/apache/accumulo/access/AccessExpressionImpl.java index 3f82004..9791643 100644 --- a/src/main/java/org/apache/accumulo/access/AccessExpressionImpl.java +++ b/src/main/java/org/apache/accumulo/access/AccessExpressionImpl.java @@ -20,99 +20,77 @@ package org.apache.accumulo.access; import static java.nio.charset.StandardCharsets.UTF_8; -import java.util.Arrays; import java.util.HashSet; -import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Predicate; class AccessExpressionImpl implements AccessExpression { - private final byte[] expression; + public static final AccessExpression EMPTY = new AccessExpressionImpl("", false); - final AeNode aeNode; + private final String expression; - private final AtomicReference<String> expressionString = new AtomicReference<>(null); - - @Override - public String getExpression() { - var expStr = expressionString.get(); - if (expStr != null) { - return expStr; + AccessExpressionImpl(String expression, boolean normalize) { + if (normalize) { + // validate and normalize expression + this.expression = normalize(expression); + } else { + validate(expression); + this.expression = expression; } - - return expressionString.updateAndGet(es -> es == null ? new String(expression, UTF_8) : es); } - // must create this after creating EMPTY_NODE - static final AccessExpression EMPTY = new AccessExpressionImpl(""); + AccessExpressionImpl(byte[] expression, boolean normalize) { + if (normalize) { + // validate and normalize expression + this.expression = normalize(expression); + } else { + validate(expression); + this.expression = new String(expression, UTF_8); + } + } @Override - public String normalize() { - StringBuilder builder = new StringBuilder(); - aeNode.normalize().stringify(builder, false); - return builder.toString(); + public String getExpression() { + return expression; } @Override - public Authorizations getAuthorizations() { - HashSet<String> auths = new HashSet<>(); - aeNode.getAuthorizations(auths::add); - return Authorizations.of(auths); + public String toString() { + return expression; } - /** - * Creates an AccessExpression. - * - * @param expression An expression of the rights needed to see specific data. The expression - * syntax is defined within the <a href= - * "https://github.com/apache/accumulo-access/blob/main/SPECIFICATION.md">specification - * doc</a> - */ - AccessExpressionImpl(String expression) { - this(expression.getBytes(UTF_8)); - expressionString.set(expression); - } + @Override + public boolean equals(Object o) { + if (o instanceof AccessExpressionImpl) { + return ((AccessExpressionImpl) o).expression.equals(expression); + } - /** - * Creates an AccessExpression from a string already encoded in UTF-8 bytes. - * - * @param expression AccessExpression, encoded as UTF-8 bytes - * @see #AccessExpressionImpl(String) - */ - AccessExpressionImpl(byte[] expression) { - this.expression = expression; - aeNode = Parser.parseAccessExpression(expression); + return false; } @Override - public String toString() { - return getExpression(); + public int hashCode() { + return expression.hashCode(); } - /** - * See {@link #equals(AccessExpressionImpl)} - */ @Override - public boolean equals(Object obj) { - if (obj instanceof AccessExpressionImpl) { - return equals((AccessExpressionImpl) obj); - } - return false; + public Authorizations getAuthorizations() { + return AccessExpressionImpl.getAuthorizations(expression); } - /** - * Compares two AccessExpressions for string equivalence, not as a meaningful comparison of terms - * and conditions. - * - * @param otherLe other AccessExpression - * @return true if this AccessExpression equals the other via string comparison - */ - boolean equals(AccessExpressionImpl otherLe) { - return Arrays.equals(expression, otherLe.expression); + static Authorizations getAuthorizations(byte[] expression) { + HashSet<String> auths = new HashSet<>(); + Tokenizer tokenizer = new Tokenizer(expression); + Predicate<Tokenizer.AuthorizationToken> atp = authToken -> { + auths.add(new String(authToken.data, authToken.start, authToken.len, UTF_8)); + return true; + }; + ParserEvaluator.parseAccessExpression(tokenizer, atp, atp); + return Authorizations.of(auths); } - @Override - public int hashCode() { - return Arrays.hashCode(expression); + static Authorizations getAuthorizations(String expression) { + return getAuthorizations(expression.getBytes(UTF_8)); } static String quote(String term) { @@ -142,4 +120,24 @@ class AccessExpressionImpl implements AccessExpression { return AccessEvaluatorImpl.escape(term, true); } + + static void validate(byte[] expression) throws IllegalAccessExpressionException { + Tokenizer tokenizer = new Tokenizer(expression); + Predicate<Tokenizer.AuthorizationToken> atp = authToken -> true; + ParserEvaluator.parseAccessExpression(tokenizer, atp, atp); + } + + static void validate(String expression) throws IllegalAccessExpressionException { + validate(expression.getBytes(UTF_8)); + } + + static String normalize(String expression) throws IllegalAccessExpressionException { + Tokenizer tokenizer = new Tokenizer(expression.getBytes(UTF_8)); + return Normalizer.normalize(tokenizer); + } + + static String normalize(byte[] expression) throws IllegalAccessExpressionException { + Tokenizer tokenizer = new Tokenizer(expression); + return Normalizer.normalize(tokenizer); + } } diff --git a/src/main/java/org/apache/accumulo/access/AeNode.java b/src/main/java/org/apache/accumulo/access/AeNode.java index dc04f6f..b9eb446 100644 --- a/src/main/java/org/apache/accumulo/access/AeNode.java +++ b/src/main/java/org/apache/accumulo/access/AeNode.java @@ -23,8 +23,6 @@ import static org.apache.accumulo.access.ByteUtils.OR_OPERATOR; import java.util.List; import java.util.TreeSet; -import java.util.function.Consumer; -import java.util.function.Predicate; /** * Contains the code for an Access Expression represented as a parse tree and all the operations on @@ -32,10 +30,6 @@ import java.util.function.Predicate; */ abstract class AeNode implements Comparable<AeNode> { - abstract boolean canAccess(Predicate<BytesWrapper> authorizedPredicate); - - abstract void getAuthorizations(Consumer<String> authConsumer); - abstract void stringify(StringBuilder builder, boolean addParens); abstract AeNode normalize(); @@ -58,13 +52,6 @@ abstract class AeNode implements Comparable<AeNode> { } private static class EmptyNode extends AeNode { - @Override - boolean canAccess(Predicate<BytesWrapper> authorizedPredicate) { - return true; - } - - @Override - void getAuthorizations(Consumer<String> authConsumer) {} @Override void stringify(StringBuilder builder, boolean addParens) { @@ -94,16 +81,6 @@ abstract class AeNode implements Comparable<AeNode> { authInExpression = new BytesWrapper(auth.data, auth.start, auth.len); } - @Override - boolean canAccess(Predicate<BytesWrapper> authorizedPredicate) { - return authorizedPredicate.test(authInExpression); - } - - @Override - void getAuthorizations(Consumer<String> authConsumer) { - authConsumer.accept(AccessEvaluatorImpl.unescape(authInExpression)); - } - @Override void stringify(StringBuilder builder, boolean addParens) { boolean needsQuotes = false; @@ -144,20 +121,28 @@ abstract class AeNode implements Comparable<AeNode> { } } - private static abstract class MultiNode extends AeNode { + private static class MultiNode extends AeNode { protected final List<AeNode> children; - private MultiNode(List<AeNode> children) { + private final byte operator; + + private MultiNode(byte operator, List<AeNode> children) { + this.operator = operator; this.children = children; } @Override - void getAuthorizations(Consumer<String> authConsumer) { - children.forEach(aeNode -> aeNode.getAuthorizations(authConsumer)); + int ordinal() { + switch (operator) { + case AND_OPERATOR: + return 3; + case OR_OPERATOR: + return 2; + default: + throw new IllegalStateException(); + } } - abstract char operator(); - @Override void stringify(StringBuilder builder, boolean addParens) { if (addParens) { @@ -167,7 +152,7 @@ abstract class AeNode implements Comparable<AeNode> { var iter = children.iterator(); iter.next().stringify(builder, true); iter.forEachRemaining(aeNode -> { - builder.append(operator()); + builder.append((char) operator); aeNode.stringify(builder, true); }); @@ -193,34 +178,11 @@ abstract class AeNode implements Comparable<AeNode> { } return cmp; } - } - - private static class AndNode extends MultiNode { - - private AndNode(List<AeNode> children) { - super(children); - } - - @Override - char operator() { - return '&'; - } - - @Override - boolean canAccess(Predicate<BytesWrapper> authorizedPredicate) { - for (var child : children) { - if (!child.canAccess(authorizedPredicate)) { - return false; - } - } - - return true; - } void flatten(TreeSet<AeNode> nodes) { for (var child : children) { - if (child instanceof AndNode) { - ((AndNode) child).flatten(nodes); + if (child instanceof MultiNode && ((MultiNode) child).operator == operator) { + ((MultiNode) child).flatten(nodes); } else { nodes.add(child.normalize()); } @@ -234,67 +196,10 @@ abstract class AeNode implements Comparable<AeNode> { if (flattened.size() == 1) { return flattened.iterator().next(); } else { - return new AndNode(List.copyOf(flattened)); + return new MultiNode(operator, List.copyOf(flattened)); } } - @Override - int ordinal() { - return 3; - } - } - - private static class OrNode extends MultiNode { - - private OrNode(List<AeNode> children) { - super(children); - } - - @Override - char operator() { - return '|'; - } - - @Override - boolean canAccess(Predicate<BytesWrapper> authorizedPredicate) { - for (var child : children) { - if (child.canAccess(authorizedPredicate)) { - return true; - } - } - - return false; - } - - void flatten(TreeSet<AeNode> nodes) { - for (var child : children) { - if (child instanceof OrNode) { - ((OrNode) child).flatten(nodes); - } else { - nodes.add(child.normalize()); - } - } - } - - @Override - AeNode normalize() { - var flattened = new TreeSet<AeNode>(); - flatten(flattened); - if (flattened.size() == 1) { - return flattened.iterator().next(); - } else { - return new OrNode(List.copyOf(flattened)); - } - } - - @Override - int ordinal() { - return 2; - } - } - - static AeNode of() { - return new EmptyNode(); } static AeNode of(Tokenizer.AuthorizationToken auth) { @@ -304,9 +209,8 @@ abstract class AeNode implements Comparable<AeNode> { static AeNode of(byte operator, List<AeNode> children) { switch (operator) { case AND_OPERATOR: - return new AndNode(children); case OR_OPERATOR: - return new OrNode(children); + return new MultiNode(operator, children); default: throw new IllegalArgumentException(); } diff --git a/src/main/java/org/apache/accumulo/access/Authorizations.java b/src/main/java/org/apache/accumulo/access/Authorizations.java index 33f3f87..3d4df1b 100644 --- a/src/main/java/org/apache/accumulo/access/Authorizations.java +++ b/src/main/java/org/apache/accumulo/access/Authorizations.java @@ -22,7 +22,7 @@ import java.util.Collection; import java.util.Set; /** - * A collection of authorization strings. + * An immutable collection of authorization strings. * * @since 1.0.0 */ @@ -52,6 +52,11 @@ public class Authorizations { return authorizations.hashCode(); } + @Override + public String toString() { + return authorizations.toString(); + } + public static Authorizations of(String... authorizations) { return new Authorizations(Set.of(authorizations)); } diff --git a/src/main/java/org/apache/accumulo/access/BytesWrapper.java b/src/main/java/org/apache/accumulo/access/BytesWrapper.java index 998a452..4f93581 100644 --- a/src/main/java/org/apache/accumulo/access/BytesWrapper.java +++ b/src/main/java/org/apache/accumulo/access/BytesWrapper.java @@ -51,28 +51,11 @@ class BytesWrapper implements Comparable<BytesWrapper> { * @throws IllegalArgumentException if the offset or length are out of bounds for the given byte * array */ - public BytesWrapper(byte[] data, int offset, int length) { - - if (offset < 0 || offset > data.length) { - throw new IllegalArgumentException( - "Offset out of bounds. data.length = " + data.length + ", offset = " + offset); - } - if (length < 0) { - throw new IllegalArgumentException("Length cannot be negative. length = " + length); - } - if ((offset + length) > data.length) { - throw new IllegalArgumentException( - "Sum of offset and length exceeds data length. data.length = " + data.length - + ", offset = " + offset + ", length = " + length); - } - - this.data = data; - this.offset = offset; - this.length = length; - + BytesWrapper(byte[] data, int offset, int length) { + set(data, offset, length); } - public byte byteAt(int i) { + byte byteAt(int i) { if (i < 0) { throw new IllegalArgumentException("i < 0, " + i); @@ -131,4 +114,23 @@ class BytesWrapper implements Comparable<BytesWrapper> { public String toString() { return new String(data, offset, length, UTF_8); } + + void set(byte[] data, int offset, int length) { + if (offset < 0) { + throw new IllegalArgumentException("Offset cannot be negative. length = " + offset); + } + if (length < 0) { + throw new IllegalArgumentException("Length cannot be negative. length = " + length); + } + if ((offset + length) > data.length) { + throw new IllegalArgumentException( + "Sum of offset and length exceeds data length. data.length = " + data.length + + ", offset = " + offset + ", length = " + length); + } + + this.data = data; + this.offset = offset; + this.length = length; + } + } diff --git a/src/main/java/org/apache/accumulo/access/Parser.java b/src/main/java/org/apache/accumulo/access/Normalizer.java similarity index 59% rename from src/main/java/org/apache/accumulo/access/Parser.java rename to src/main/java/org/apache/accumulo/access/Normalizer.java index 17b7f7a..518d2fd 100644 --- a/src/main/java/org/apache/accumulo/access/Parser.java +++ b/src/main/java/org/apache/accumulo/access/Normalizer.java @@ -21,21 +21,14 @@ package org.apache.accumulo.access; import static org.apache.accumulo.access.ByteUtils.isAndOrOperator; import java.util.ArrayList; +import java.util.List; -/** - * Code for parsing an access expression and creating a parse tree of type {@link AeNode} - */ -final class Parser { - - public static final byte OPEN_PAREN = (byte) '('; - public static final byte CLOSE_PAREN = (byte) ')'; +class Normalizer { - public static AeNode parseAccessExpression(byte[] expression) { - - Tokenizer tokenizer = new Tokenizer(expression); + static String normalize(Tokenizer tokenizer) { if (!tokenizer.hasNext()) { - return AeNode.of(); + return ""; } var node = parseExpression(tokenizer); @@ -45,35 +38,36 @@ final class Parser { tokenizer.error("Unexpected character '" + (char) tokenizer.peek() + "'"); } - return node; + StringBuilder builder = new StringBuilder(); + node.normalize().stringify(builder, false); + return builder.toString(); } private static AeNode parseExpression(Tokenizer tokenizer) { - AeNode first = parseParenExpressionOrAuthorization(tokenizer); - - if (tokenizer.hasNext() && isAndOrOperator(tokenizer.peek())) { - var nodes = new ArrayList<AeNode>(); - nodes.add(first); + AeNode node = parseParenExpressionOrAuthorization(tokenizer); + if (tokenizer.hasNext()) { var operator = tokenizer.peek(); - - do { - tokenizer.advance(); - - nodes.add(parseParenExpressionOrAuthorization(tokenizer)); - - } while (tokenizer.hasNext() && tokenizer.peek() == operator); - - if (tokenizer.hasNext() && isAndOrOperator(tokenizer.peek())) { - // A case of mixed operators, lets give a clear error message - tokenizer.error("Cannot mix '|' and '&'"); + if (isAndOrOperator(operator)) { + List<AeNode> nodes = new ArrayList<>(); + nodes.add(node); + do { + tokenizer.advance(); + AeNode next = parseParenExpressionOrAuthorization(tokenizer); + nodes.add(next); + } while (tokenizer.hasNext() && tokenizer.peek() == operator); + + if (tokenizer.hasNext() && isAndOrOperator(tokenizer.peek())) { + // A case of mixed operators, lets give a clear error message + tokenizer.error("Cannot mix '|' and '&'"); + } + + node = AeNode.of(operator, nodes); } - - return AeNode.of(operator, nodes); - } else { - return first; } + + return node; } private static AeNode parseParenExpressionOrAuthorization(Tokenizer tokenizer) { @@ -82,13 +76,14 @@ final class Parser { .error("Expected a '(' character or an authorization token instead saw end of input"); } - if (tokenizer.peek() == OPEN_PAREN) { + if (tokenizer.peek() == ParserEvaluator.OPEN_PAREN) { tokenizer.advance(); var node = parseExpression(tokenizer); - tokenizer.next(CLOSE_PAREN); + tokenizer.next(ParserEvaluator.CLOSE_PAREN); return node; } else { - return AeNode.of(tokenizer.nextAuthorization()); + var auth = tokenizer.nextAuthorization(); + return AeNode.of(auth); } } } diff --git a/src/main/java/org/apache/accumulo/access/ParserEvaluator.java b/src/main/java/org/apache/accumulo/access/ParserEvaluator.java new file mode 100644 index 0000000..25cdd86 --- /dev/null +++ b/src/main/java/org/apache/accumulo/access/ParserEvaluator.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.accumulo.access; + +import static org.apache.accumulo.access.ByteUtils.isAndOrOperator; + +import java.util.function.Predicate; + +/** + * Code for parsing and evaluating an access expression at the same time. + */ +final class ParserEvaluator { + + static final byte OPEN_PAREN = (byte) '('; + static final byte CLOSE_PAREN = (byte) ')'; + + static boolean parseAccessExpression(Tokenizer tokenizer, + Predicate<Tokenizer.AuthorizationToken> authorizedPredicate, + Predicate<Tokenizer.AuthorizationToken> shortCircuitPredicate) { + + if (!tokenizer.hasNext()) { + return true; + } + + var node = parseExpression(tokenizer, authorizedPredicate, shortCircuitPredicate); + + if (tokenizer.hasNext()) { + // not all input was read, so not a valid expression + tokenizer.error("Unexpected character '" + (char) tokenizer.peek() + "'"); + } + + return node; + } + + private static boolean parseExpression(Tokenizer tokenizer, + Predicate<Tokenizer.AuthorizationToken> authorizedPredicate, + Predicate<Tokenizer.AuthorizationToken> shortCircuitPredicate) { + + boolean result = + parseParenExpressionOrAuthorization(tokenizer, authorizedPredicate, shortCircuitPredicate); + + if (tokenizer.hasNext()) { + var operator = tokenizer.peek(); + if (operator == ByteUtils.AND_OPERATOR) { + result = parseAndExpression(result, tokenizer, authorizedPredicate, shortCircuitPredicate); + if (tokenizer.hasNext() && isAndOrOperator(tokenizer.peek())) { + // A case of mixed operators, lets give a clear error message + tokenizer.error("Cannot mix '|' and '&'"); + } + } else if (operator == ByteUtils.OR_OPERATOR) { + result = parseOrExpression(result, tokenizer, authorizedPredicate, shortCircuitPredicate); + if (tokenizer.hasNext() && isAndOrOperator(tokenizer.peek())) { + // A case of mixed operators, lets give a clear error message + tokenizer.error("Cannot mix '|' and '&'"); + } + } + } + + return result; + } + + private static boolean parseAndExpression(boolean result, Tokenizer tokenizer, + Predicate<Tokenizer.AuthorizationToken> authorizedPredicate, + Predicate<Tokenizer.AuthorizationToken> shortCircuitPredicate) { + do { + if (!result) { + // Once the "and" expression is false, can avoid doing set lookups and only validate the + // rest of the expression. + authorizedPredicate = shortCircuitPredicate; + } + tokenizer.advance(); + var nextResult = parseParenExpressionOrAuthorization(tokenizer, authorizedPredicate, + shortCircuitPredicate); + result &= nextResult; + } while (tokenizer.hasNext() && tokenizer.peek() == ByteUtils.AND_OPERATOR); + return result; + } + + private static boolean parseOrExpression(boolean result, Tokenizer tokenizer, + Predicate<Tokenizer.AuthorizationToken> authorizedPredicate, + Predicate<Tokenizer.AuthorizationToken> shortCircuitPredicate) { + do { + if (result) { + // Once the "or" expression is true, can avoid doing set lookups and only validate the rest + // of the expression. + authorizedPredicate = shortCircuitPredicate; + } + tokenizer.advance(); + var nextResult = parseParenExpressionOrAuthorization(tokenizer, authorizedPredicate, + shortCircuitPredicate); + result |= nextResult; + } while (tokenizer.hasNext() && tokenizer.peek() == ByteUtils.OR_OPERATOR); + return result; + } + + private static boolean parseParenExpressionOrAuthorization(Tokenizer tokenizer, + Predicate<Tokenizer.AuthorizationToken> authorizedPredicate, + Predicate<Tokenizer.AuthorizationToken> shortCircuitPredicate) { + if (!tokenizer.hasNext()) { + tokenizer + .error("Expected a '(' character or an authorization token instead saw end of input"); + } + + if (tokenizer.peek() == OPEN_PAREN) { + tokenizer.advance(); + var node = parseExpression(tokenizer, authorizedPredicate, shortCircuitPredicate); + tokenizer.next(CLOSE_PAREN); + return node; + } else { + var auth = tokenizer.nextAuthorization(); + return authorizedPredicate.test(auth); + } + } +} diff --git a/src/main/java/org/apache/accumulo/access/Tokenizer.java b/src/main/java/org/apache/accumulo/access/Tokenizer.java index b5e3f14..1f44975 100644 --- a/src/main/java/org/apache/accumulo/access/Tokenizer.java +++ b/src/main/java/org/apache/accumulo/access/Tokenizer.java @@ -27,8 +27,8 @@ import java.util.stream.IntStream; /** * A simple wrapper around a byte array that keeps some state and provides high level operations to - * the {@link Parser} class. The purpose of this class is to make {@link Parser} as simple and easy - * to understand as possible while still being performant. + * the {@link ParserEvaluator} class. The purpose of this class is to make {@link ParserEvaluator} + * as simple and easy to understand as possible while still being performant. */ final class Tokenizer { @@ -49,7 +49,7 @@ final class Tokenizer { return validAuthChars[0xff & b]; } - private final byte[] expression; + private byte[] expression; private int index; private final AuthorizationToken authorizationToken = new AuthorizationToken(); @@ -65,6 +65,12 @@ final class Tokenizer { authorizationToken.data = expression; } + public void reset(byte[] expression) { + this.expression = expression; + authorizationToken.data = expression; + this.index = 0; + } + boolean hasNext() { return index < expression.length; } diff --git a/src/test/java/org/apache/accumulo/access/AccessEvaluatorTest.java b/src/test/java/org/apache/accumulo/access/AccessEvaluatorTest.java index 97d7119..a041b8a 100644 --- a/src/test/java/org/apache/accumulo/access/AccessEvaluatorTest.java +++ b/src/test/java/org/apache/accumulo/access/AccessEvaluatorTest.java @@ -29,6 +29,7 @@ import java.io.IOException; import java.lang.reflect.Type; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -113,30 +114,47 @@ public class AccessEvaluatorTest { assertTrue(tests.expressions.length > 0); for (var expression : tests.expressions) { + + // Call various APIs with well-formed access expressions to ensure they do not throw an + // exception + if (tests.expectedResult == ExpectedResult.ACCESSIBLE + || tests.expectedResult == ExpectedResult.INACCESSIBLE) { + AccessExpression.validate(expression); + AccessExpression.validate(expression.getBytes(UTF_8)); + assertEquals(expression, AccessExpression.of(expression).getExpression()); + assertEquals(expression, AccessExpression.of(expression, false).getExpression()); + assertEquals(expression, AccessExpression.of(expression.getBytes(UTF_8)).getExpression()); + assertEquals(expression, + AccessExpression.of(expression.getBytes(UTF_8), false).getExpression()); + Objects.requireNonNull(AccessExpression.of(expression).getAuthorizations()); + Objects + .requireNonNull(AccessExpression.of(expression.getBytes(UTF_8)).getAuthorizations()); + } + switch (tests.expectedResult) { case ACCESSIBLE: assertTrue(evaluator.canAccess(expression), expression); assertTrue(evaluator.canAccess(expression.getBytes(UTF_8)), expression); assertTrue(evaluator.canAccess(AccessExpression.of(expression)), expression); - assertTrue(evaluator.canAccess(AccessExpression.of(expression.getBytes(UTF_8))), + assertTrue(evaluator.canAccess(AccessExpression.of(expression, true)), expression); + assertTrue(evaluator.canAccess(AccessExpression.of(expression, true).getExpression()), expression); - assertTrue(evaluator.canAccess(AccessExpression.of(expression).normalize()), + assertTrue( + evaluator.canAccess( + AccessExpression.of(expression.getBytes(UTF_8), true).getExpression()), expression); - assertEquals(expression, - AccessExpression.of(expression.getBytes(UTF_8)).getExpression()); - assertEquals(expression, AccessExpression.of(expression).getExpression()); break; case INACCESSIBLE: assertFalse(evaluator.canAccess(expression), expression); assertFalse(evaluator.canAccess(expression.getBytes(UTF_8)), expression); assertFalse(evaluator.canAccess(AccessExpression.of(expression)), expression); - assertFalse(evaluator.canAccess(AccessExpression.of(expression.getBytes(UTF_8))), + assertFalse(evaluator.canAccess(AccessExpression.of(expression, true)), expression); + assertFalse(evaluator.canAccess(AccessExpression.of(expression, true).getExpression()), expression); - assertFalse(evaluator.canAccess(AccessExpression.of(expression).normalize()), + assertFalse( + evaluator.canAccess( + AccessExpression.of(expression.getBytes(UTF_8), true).getExpression()), expression); - assertEquals(expression, - AccessExpression.of(expression.getBytes(UTF_8)).getExpression()); - assertEquals(expression, AccessExpression.of(expression).getExpression()); break; case ERROR: assertThrows(IllegalAccessExpressionException.class, @@ -144,10 +162,21 @@ public class AccessEvaluatorTest { assertThrows(IllegalAccessExpressionException.class, () -> evaluator.canAccess(expression.getBytes(UTF_8)), expression); assertThrows(IllegalAccessExpressionException.class, - () -> evaluator.canAccess(AccessExpression.of(expression)), expression); + () -> AccessExpression.validate(expression), expression); assertThrows(IllegalAccessExpressionException.class, - () -> evaluator.canAccess(AccessExpression.of(expression.getBytes(UTF_8))), - expression); + () -> AccessExpression.validate(expression.getBytes(UTF_8)), expression); + assertThrows(IllegalAccessExpressionException.class, + () -> AccessExpression.of(expression), expression); + assertThrows(IllegalAccessExpressionException.class, + () -> AccessExpression.of(expression, false), expression); + assertThrows(IllegalAccessExpressionException.class, + () -> AccessExpression.of(expression, true), expression); + assertThrows(IllegalAccessExpressionException.class, + () -> AccessExpression.of(expression.getBytes(UTF_8)), expression); + assertThrows(IllegalAccessExpressionException.class, + () -> AccessExpression.of(expression.getBytes(UTF_8), false), expression); + assertThrows(IllegalAccessExpressionException.class, + () -> AccessExpression.of(expression.getBytes(UTF_8), true), expression); break; default: throw new IllegalArgumentException(); diff --git a/src/test/java/org/apache/accumulo/access/AccessExpressionBenchmark.java b/src/test/java/org/apache/accumulo/access/AccessExpressionBenchmark.java index b04b789..af9a14b 100644 --- a/src/test/java/org/apache/accumulo/access/AccessExpressionBenchmark.java +++ b/src/test/java/org/apache/accumulo/access/AccessExpressionBenchmark.java @@ -58,7 +58,6 @@ public class AccessExpressionBenchmark { public static class EvaluatorTests { AccessEvaluator evaluator; - List<AccessExpression> parsedExpressions; List<byte[]> expressions; } @@ -83,7 +82,6 @@ public class AccessExpressionBenchmark { for (var testDataSet : testData) { EvaluatorTests et = new EvaluatorTests(); - et.parsedExpressions = new ArrayList<>(); et.expressions = new ArrayList<>(); if (testDataSet.auths.length == 1) { @@ -101,7 +99,6 @@ public class AccessExpressionBenchmark { byte[] byteExp = exp.getBytes(UTF_8); allTestExpressions.add(byteExp); et.expressions.add(byteExp); - et.parsedExpressions.add(AccessExpression.of(exp)); } } } @@ -128,9 +125,9 @@ public class AccessExpressionBenchmark { * Measures the time it takes to parse an expression stored in byte[] and produce a parse tree. */ @Benchmark - public void measureBytesParsing(BenchmarkState state, Blackhole blackhole) { + public void measureBytesValidation(BenchmarkState state, Blackhole blackhole) { for (byte[] accessExpression : state.getBytesExpressions()) { - blackhole.consume(AccessExpression.of(accessExpression)); + AccessExpression.validate(accessExpression); } } @@ -138,21 +135,9 @@ public class AccessExpressionBenchmark { * Measures the time it takes to parse an expression stored in a String and produce a parse tree. */ @Benchmark - public void measureStringParsing(BenchmarkState state, Blackhole blackhole) { + public void measureStringValidation(BenchmarkState state, Blackhole blackhole) { for (String accessExpression : state.getStringExpressions()) { - blackhole.consume(AccessExpression.of(accessExpression)); - } - } - - /** - * Measures the time it takes to evaluate a previously parsed expression. - */ - @Benchmark - public void measureEvaluation(BenchmarkState state, Blackhole blackhole) { - for (EvaluatorTests evaluatorTests : state.getEvaluatorTests()) { - for (AccessExpression expression : evaluatorTests.parsedExpressions) { - blackhole.consume(evaluatorTests.evaluator.canAccess(expression)); - } + AccessExpression.validate(accessExpression); } } @@ -161,7 +146,7 @@ public class AccessExpressionBenchmark { * tree an operate on it. */ @Benchmark - public void measureEvaluationAndParsing(BenchmarkState state, Blackhole blackhole) { + public void measureEvaluation(BenchmarkState state, Blackhole blackhole) { for (EvaluatorTests evaluatorTests : state.getEvaluatorTests()) { for (byte[] expression : evaluatorTests.expressions) { blackhole.consume(evaluatorTests.evaluator.canAccess(expression)); @@ -182,8 +167,6 @@ public class AccessExpressionBenchmark { .mode(Mode.Throughput).operationsPerInvocation(numExpressions) .timeUnit(TimeUnit.MICROSECONDS).warmupTime(TimeValue.seconds(5)).warmupIterations(3) .measurementIterations(4).forks(3).build(); - new Runner(opt).run(); } - } diff --git a/src/test/java/org/apache/accumulo/access/AccessExpressionTest.java b/src/test/java/org/apache/accumulo/access/AccessExpressionTest.java index bf76ac9..266be77 100644 --- a/src/test/java/org/apache/accumulo/access/AccessExpressionTest.java +++ b/src/test/java/org/apache/accumulo/access/AccessExpressionTest.java @@ -20,6 +20,7 @@ package org.apache.accumulo.access; import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -28,14 +29,10 @@ import java.util.List; import java.util.stream.Collectors; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; public class AccessExpressionTest { - @Test - public void testEmptyExpression() { - assertEquals("", AccessExpression.of().getExpression()); - } - @Test public void testGetAuthorizations() { // Test data pairs where the first entry of each pair is an expression to normalize and second @@ -73,6 +70,7 @@ public class AccessExpressionTest { testData.add(List.of("", "")); testData.add(List.of("a", "a")); + testData.add(List.of("\"a\"", "a")); testData.add(List.of("(a)", "a")); testData.add(List.of("b|a", "a|b")); testData.add(List.of("(b)|a", "a|b")); @@ -103,21 +101,42 @@ public class AccessExpressionTest { testData.add(List.of("a&a&a&a", "a")); testData.add(List.of("(a|a)|(a|a)", "a")); testData.add(List.of("(a&a)&(a&a)", "a")); + var auth1 = "\"ABC\""; + var auth2 = "\"QRS\""; + var auth3 = "\"X&Z\""; + testData.add(List.of( + "(" + auth1 + "&" + auth2 + "&" + auth3 + ")|(" + auth3 + "&" + auth1 + "&" + auth2 + ")", + "ABC&QRS&\"X&Z\"")); for (var testCase : testData) { assertEquals(2, testCase.size()); var expression = testCase.get(0); var expected = testCase.get(1); - var normalized = AccessExpression.of(expression).normalize(); - assertEquals(expected, normalized); - assertEquals(expected, AccessExpression.of(expression.getBytes(UTF_8)).normalize()); - assertEquals(expected, AccessExpression.of(normalized).normalize()); + assertEquals(expected, AccessExpression.of(expression, true).getExpression()); + assertEquals(expected, AccessExpression.of(expression.getBytes(UTF_8), true).getExpression()); + + // when not normalizing should see the original expression + assertEquals(expression, AccessExpression.of(expression).getExpression()); + assertEquals(expression, AccessExpression.of(expression, false).getExpression()); + assertEquals(expression, AccessExpression.of(expression.getBytes(UTF_8)).getExpression()); + assertEquals(expression, + AccessExpression.of(expression.getBytes(UTF_8), false).getExpression()); } } void checkError(String expression, String expected, int index) { - var exception = - assertThrows(IllegalAccessExpressionException.class, () -> AccessExpression.of(expression)); + checkError(() -> AccessExpression.validate(expression), expected, index); + checkError(() -> AccessExpression.validate(expression.getBytes(UTF_8)), expected, index); + checkError(() -> AccessExpression.of(expression), expected, index); + checkError(() -> AccessExpression.of(expression, true), expected, index); + checkError(() -> AccessExpression.of(expression, false), expected, index); + checkError(() -> AccessExpression.of(expression.getBytes(UTF_8)), expected, index); + checkError(() -> AccessExpression.of(expression.getBytes(UTF_8), true), expected, index); + checkError(() -> AccessExpression.of(expression.getBytes(UTF_8), false), expected, index); + } + + void checkError(Executable executable, String expected, int index) { + var exception = assertThrows(IllegalAccessExpressionException.class, executable); assertTrue(exception.getMessage().contains(expected)); assertEquals(index, exception.getIndex()); } @@ -156,4 +175,18 @@ public class AccessExpressionTest { checkError("\"\\9\"", "Invalid escaping within quotes", 1); checkError("ERR&\"\\9\"", "Invalid escaping within quotes", 5); } + + @Test + public void testEqualsHashcode() { + var ae1 = AccessExpression.of("A&B"); + var ae2 = AccessExpression.of("A&C"); + var ae3 = AccessExpression.of("B&A", true); + + assertEquals(ae1, ae3); + assertNotEquals(ae1, ae2); + assertNotEquals(ae3, ae2); + + assertEquals(ae1.hashCode(), ae3.hashCode()); + assertNotEquals(ae1.hashCode(), ae2.hashCode()); + } }