This is an automated email from the ASF dual-hosted git repository.

kusal pushed a commit to branch WW-5350-allowlist-2
in repository https://gitbox.apache.org/repos/asf/struts.git

commit bef976917013090ef15124f8996344a868d835d1
Author: Kusal Kithul-Godage <g...@kusal.io>
AuthorDate: Sun Nov 5 23:05:35 2023 +1100

    WW-5350 Implement OGNL Allowlist capability
---
 .../com/opensymphony/xwork2/ognl/OgnlUtil.java     | 78 ++++++++++++++++++----
 .../opensymphony/xwork2/ognl/OgnlValueStack.java   |  3 +
 .../xwork2/ognl/SecurityMemberAccess.java          | 56 ++++++++++++++--
 .../java/org/apache/struts2/StrutsConstants.java   |  6 ++
 .../xwork2/ognl/SecurityMemberAccessTest.java      | 77 +++++++++++++++++++++
 5 files changed, 203 insertions(+), 17 deletions(-)

diff --git a/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlUtil.java b/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlUtil.java
index 7c9b5fbbb..1f019f64a 100644
--- a/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlUtil.java
+++ b/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlUtil.java
@@ -59,6 +59,9 @@ import static java.util.Collections.unmodifiableSet;
 import static java.util.Objects.requireNonNull;
 import static java.util.stream.Collectors.toSet;
 import static org.apache.commons.lang3.StringUtils.strip;
+import static org.apache.struts2.StrutsConstants.STRUTS_ALLOWLIST_CLASSES;
+import static org.apache.struts2.StrutsConstants.STRUTS_ALLOWLIST_ENABLE;
+import static org.apache.struts2.StrutsConstants.STRUTS_ALLOWLIST_PACKAGE_NAMES;
 import static org.apache.struts2.ognl.OgnlGuard.EXPR_BLOCKED;
 
 
@@ -89,6 +92,10 @@ public class OgnlUtil {
     private Set<String> excludedPackageNames = emptySet();
     private Set<String> excludedPackageExemptClasses = emptySet();
 
+    private boolean enforceAllowlistEnabled = false;
+    private Set<String> allowlistClasses = emptySet();
+    private Set<String> allowlistPackageNames = emptySet();
+
     private Set<String> devModeExcludedClasses = emptySet();
     private Set<Pattern> devModeExcludedPackageNamePatterns = emptySet();
     private Set<String> devModeExcludedPackageNames = emptySet();
@@ -96,8 +103,8 @@ public class OgnlUtil {
 
     private Container container;
     private boolean allowStaticFieldAccess = true;
-    private boolean disallowProxyMemberAccess;
-    private boolean disallowDefaultPackageAccess;
+    private boolean disallowProxyMemberAccess = false;
+    private boolean disallowDefaultPackageAccess = false;
 
     /**
      * Construct a new OgnlUtil instance for use with the framework
@@ -178,6 +185,12 @@ public class OgnlUtil {
         devModeExcludedClasses = toNewClassesSet(devModeExcludedClasses, commaDelimitedClasses);
     }
 
+    private static Set<String> toClassesSet(String newDelimitedClasses) throws ConfigurationException {
+        Set<String> classNames = commaDelimitedStringToSet(newDelimitedClasses);
+        validateClasses(classNames, OgnlUtil.class.getClassLoader());
+        return unmodifiableSet(classNames);
+    }
+
     private static Set<String> toNewClassesSet(Set<String> oldClasses, String newDelimitedClasses) throws ConfigurationException {
         Set<String> classNames = commaDelimitedStringToSet(newDelimitedClasses);
         validateClasses(classNames, OgnlUtil.class.getClassLoader());
@@ -229,25 +242,36 @@ public class OgnlUtil {
         devModeExcludedPackageNames = toNewPackageNamesSet(devModeExcludedPackageNames, commaDelimitedPackageNames);
     }
 
-    private static Set<String> toNewPackageNamesSet(Set<String> oldPackageNames, String newDelimitedPackageNames) throws ConfigurationException {
+    private static Set<String> toPackageNamesSet(String newDelimitedPackageNames) throws ConfigurationException {
         Set<String> packageNames = commaDelimitedStringToSet(newDelimitedPackageNames)
                 .stream().map(s -> strip(s, ".")).collect(toSet());
-        if (packageNames.stream().anyMatch(s -> Pattern.compile("\\s").matcher(s).find())) {
-            throw new ConfigurationException("Excluded package names could not be parsed due to erroneous whitespace characters: " + newDelimitedPackageNames);
-        }
+        validatePackageNames(packageNames);
+        return unmodifiableSet(packageNames);
+    }
+
+    private static Set<String> toNewPackageNamesSet(Collection<String> oldPackageNames, String newDelimitedPackageNames) throws ConfigurationException {
+        Set<String> packageNames = commaDelimitedStringToSet(newDelimitedPackageNames)
+                .stream().map(s -> strip(s, ".")).collect(toSet());
+        validatePackageNames(packageNames);
         Set<String> newPackageNames = new HashSet<>(oldPackageNames);
         newPackageNames.addAll(packageNames);
         return unmodifiableSet(newPackageNames);
     }
 
+    private static void validatePackageNames(Collection<String> packageNames) {
+        if (packageNames.stream().anyMatch(s -> Pattern.compile("\\s").matcher(s).find())) {
+            throw new ConfigurationException("Excluded package names could not be parsed due to erroneous whitespace characters: " + packageNames);
+        }
+    }
+
     @Inject(value = StrutsConstants.STRUTS_EXCLUDED_PACKAGE_EXEMPT_CLASSES, required = false)
     public void setExcludedPackageExemptClasses(String commaDelimitedClasses) {
-        excludedPackageExemptClasses = toNewClassesSet(excludedPackageExemptClasses, commaDelimitedClasses);
+        excludedPackageExemptClasses = toClassesSet(commaDelimitedClasses);
     }
 
     @Inject(value = StrutsConstants.STRUTS_DEV_MODE_EXCLUDED_PACKAGE_EXEMPT_CLASSES, required = false)
     public void setDevModeExcludedPackageExemptClasses(String commaDelimitedClasses) {
-        devModeExcludedPackageExemptClasses = toNewClassesSet(devModeExcludedPackageExemptClasses, commaDelimitedClasses);
+        devModeExcludedPackageExemptClasses = toClassesSet(commaDelimitedClasses);
     }
 
     public Set<String> getExcludedClasses() {
@@ -266,6 +290,33 @@ public class OgnlUtil {
         return excludedPackageExemptClasses;
     }
 
+    @Inject(value = STRUTS_ALLOWLIST_ENABLE, required = false)
+    protected void setEnforceAllowlistEnabled(String enforceAllowlistEnabled) {
+        this.enforceAllowlistEnabled = BooleanUtils.toBoolean(enforceAllowlistEnabled);
+    }
+
+    @Inject(value = STRUTS_ALLOWLIST_CLASSES, required = false)
+    protected void setAllowlistClasses(String commaDelimitedClasses) {
+        allowlistClasses = toClassesSet(commaDelimitedClasses);
+    }
+
+    @Inject(value = STRUTS_ALLOWLIST_PACKAGE_NAMES, required = false)
+    protected void setAllowlistPackageNames(String commaDelimitedPackageNames) {
+        allowlistPackageNames = toPackageNamesSet(commaDelimitedPackageNames);
+    }
+
+    public boolean isEnforceAllowlistEnabled() {
+        return enforceAllowlistEnabled;
+    }
+
+    public Set<String> getAllowlistClasses() {
+        return allowlistClasses;
+    }
+
+    public Set<String> getAllowlistPackageNames() {
+        return allowlistPackageNames;
+    }
+
     @Inject
     protected void setContainer(Container container) {
         this.container = container;
@@ -884,10 +935,13 @@ public class OgnlUtil {
             memberAccess.useExcludedPackageNames(devModeExcludedPackageNames);
             memberAccess.useExcludedPackageExemptClasses(devModeExcludedPackageExemptClasses);
         } else {
-            memberAccess.useExcludedClasses(excludedClasses);
-            memberAccess.useExcludedPackageNamePatterns(excludedPackageNamePatterns);
-            memberAccess.useExcludedPackageNames(excludedPackageNames);
-            memberAccess.useExcludedPackageExemptClasses(excludedPackageExemptClasses);
+            memberAccess.useExcludedClasses(getExcludedClasses());
+            memberAccess.useExcludedPackageNamePatterns(getExcludedPackageNamePatterns());
+            memberAccess.useExcludedPackageNames(getExcludedPackageNames());
+            memberAccess.useExcludedPackageExemptClasses(getExcludedPackageExemptClasses());
+            memberAccess.useEnforceAllowlistEnabled(isEnforceAllowlistEnabled());
+            memberAccess.useAllowlistClasses(getAllowlistClasses());
+            memberAccess.useAllowlistPackageNames(getAllowlistPackageNames());
         }
 
         return Ognl.createDefaultContext(root, memberAccess, resolver, defaultConverter);
diff --git a/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlValueStack.java b/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlValueStack.java
index 936619ae4..01b6af81d 100644
--- a/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlValueStack.java
+++ b/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlValueStack.java
@@ -93,6 +93,9 @@ public class OgnlValueStack implements Serializable, ValueStack, ClearableValueS
         securityMemberAccess.useExcludedPackageNamePatterns(ognlUtil.getExcludedPackageNamePatterns());
         securityMemberAccess.useExcludedPackageNames(ognlUtil.getExcludedPackageNames());
         securityMemberAccess.useExcludedPackageExemptClasses(ognlUtil.getExcludedPackageExemptClasses());
+        securityMemberAccess.useEnforceAllowlistEnabled(ognlUtil.isEnforceAllowlistEnabled());
+        securityMemberAccess.useAllowlistClasses(ognlUtil.getAllowlistClasses());
+        securityMemberAccess.useAllowlistPackageNames(ognlUtil.getAllowlistPackageNames());
         securityMemberAccess.disallowProxyMemberAccess(ognlUtil.isDisallowProxyMemberAccess());
         securityMemberAccess.disallowDefaultPackageAccess(ognlUtil.isDisallowDefaultPackageAccess());
     }
diff --git a/core/src/main/java/com/opensymphony/xwork2/ognl/SecurityMemberAccess.java b/core/src/main/java/com/opensymphony/xwork2/ognl/SecurityMemberAccess.java
index e993f3adb..beb45a2dd 100644
--- a/core/src/main/java/com/opensymphony/xwork2/ognl/SecurityMemberAccess.java
+++ b/core/src/main/java/com/opensymphony/xwork2/ognl/SecurityMemberAccess.java
@@ -56,8 +56,11 @@ public class SecurityMemberAccess implements MemberAccess {
     private Set<Pattern> excludedPackageNamePatterns = emptySet();
     private Set<String> excludedPackageNames = emptySet();
     private Set<String> excludedPackageExemptClasses = emptySet();
-    private boolean disallowProxyMemberAccess;
-    private boolean disallowDefaultPackageAccess;
+    private boolean enforceAllowlistEnabled = false;
+    private Set<String> allowlistClasses = emptySet();
+    private Set<String> allowlistPackageNames = emptySet();
+    private boolean disallowProxyMemberAccess = false;
+    private boolean disallowDefaultPackageAccess = false;
 
     /**
      * SecurityMemberAccess
@@ -149,6 +152,10 @@ public class SecurityMemberAccess implements MemberAccess {
             return false;
         }
 
+        if (!checkAllowlist(target, member)) {
+            return false;
+        }
+
         if (!isAcceptableProperty(propertyName)) {
             return false;
         }
@@ -156,6 +163,33 @@ public class SecurityMemberAccess implements MemberAccess {
         return true;
     }
 
+    /**
+     * @return {@code true} if member access is allowed
+     */
+    protected boolean checkAllowlist(Object target, Member member) {
+        Class<?> memberClass = member.getDeclaringClass();
+        if (!enforceAllowlistEnabled) {
+            return true;
+        }
+        if (!isClassAllowlisted(memberClass)) {
+            LOG.warn(format("Declaring class [{0}] of member type [{1}] is not allowlisted!", memberClass, member));
+            return false;
+        }
+        if (target == null || target.getClass() == memberClass) {
+            return true;
+        }
+        Class<?> targetClass = target.getClass();
+        if (!isClassAllowlisted(targetClass)) {
+            LOG.warn(format("Target class [{0}] of target [{1}] is not allowlisted!", targetClass, target));
+            return false;
+        }
+        return true;
+    }
+
+    protected boolean isClassAllowlisted(Class<?> clazz) {
+        return allowlistClasses.contains(clazz.getName()) || isClassBelongsToPackages(clazz, allowlistPackageNames);
+    }
+
     /**
      * @return {@code true} if member access is allowed
      */
@@ -286,14 +320,14 @@ public class SecurityMemberAccess implements MemberAccess {
     }
 
     protected boolean isExcludedPackageNames(Class<?> clazz) {
-        return isExcludedPackageNamesStatic(clazz, excludedPackageNames);
+        return isClassBelongsToPackages(clazz, excludedPackageNames);
     }
 
-    public static boolean isExcludedPackageNamesStatic(Class<?> clazz, Set<String> excludedPackageNames) {
+    public static boolean isClassBelongsToPackages(Class<?> clazz, Set<String> matchingPackages) {
         List<String> packageParts = Arrays.asList(toPackageName(clazz).split("\\."));
         for (int i = 0; i < packageParts.size(); i++) {
             String parentPackage = String.join(".", packageParts.subList(0, i + 1));
-            if (excludedPackageNames.contains(parentPackage)) {
+            if (matchingPackages.contains(parentPackage)) {
                 return true;
             }
         }
@@ -399,6 +433,18 @@ public class SecurityMemberAccess implements MemberAccess {
         this.excludedPackageExemptClasses = excludedPackageExemptClasses;
     }
 
+    public void useEnforceAllowlistEnabled(boolean enforceAllowlistEnabled) {
+        this.enforceAllowlistEnabled = enforceAllowlistEnabled;
+    }
+
+    public void useAllowlistClasses(Set<String> allowlistClasses) {
+        this.allowlistClasses = allowlistClasses;
+    }
+
+    public void useAllowlistPackageNames(Set<String> allowlistPackageNames) {
+        this.allowlistPackageNames = allowlistPackageNames;
+    }
+
     /**
      * @deprecated please use {@link #disallowProxyMemberAccess(boolean)}
      */
diff --git a/core/src/main/java/org/apache/struts2/StrutsConstants.java b/core/src/main/java/org/apache/struts2/StrutsConstants.java
index bfe639d2e..bcb07a69a 100644
--- a/core/src/main/java/org/apache/struts2/StrutsConstants.java
+++ b/core/src/main/java/org/apache/struts2/StrutsConstants.java
@@ -432,6 +432,12 @@ public final class StrutsConstants {
     public static final String STRUTS_DEV_MODE_EXCLUDED_PACKAGE_NAMES = "struts.devMode.excludedPackageNames";
     public static final String STRUTS_DEV_MODE_EXCLUDED_PACKAGE_EXEMPT_CLASSES = "struts.devMode.excludedPackageExemptClasses";
 
+    /** Boolean to enable strict allowlist processing of all OGNL expression calls. */
+    public static final String STRUTS_ALLOWLIST_ENABLE = "struts.allowlist.enable";
+    /** Comma delimited set of allowed classes which CAN be accessed via OGNL expressions. Both target and member classes of OGNL expression must be allowlisted. */
+    public static final String STRUTS_ALLOWLIST_CLASSES = "struts.allowlist.classes";
+    /** Comma delimited set of package names, of which all its classes, and all classes in its subpackages, CAN be accessed via OGNL expressions. Both target and member classes of OGNL expression must be allowlisted. */
+    public static final String STRUTS_ALLOWLIST_PACKAGE_NAMES = "struts.allowlist.packageNames";
 
     /** Dedicated services to check if passed string is excluded/accepted */
     public static final String STRUTS_EXCLUDED_PATTERNS_CHECKER = "struts.excludedPatterns.checker";
diff --git a/core/src/test/java/com/opensymphony/xwork2/ognl/SecurityMemberAccessTest.java b/core/src/test/java/com/opensymphony/xwork2/ognl/SecurityMemberAccessTest.java
index 7ffc47b9f..08a3b919e 100644
--- a/core/src/test/java/com/opensymphony/xwork2/ognl/SecurityMemberAccessTest.java
+++ b/core/src/test/java/com/opensymphony/xwork2/ognl/SecurityMemberAccessTest.java
@@ -18,12 +18,15 @@
  */
 package com.opensymphony.xwork2.ognl;
 
+import com.opensymphony.xwork2.TestBean;
+import com.opensymphony.xwork2.test.TestBean2;
 import com.opensymphony.xwork2.util.TextParseUtil;
 import org.junit.Before;
 import org.junit.Test;
 
 import java.lang.reflect.Field;
 import java.lang.reflect.Member;
+import java.lang.reflect.Method;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -32,6 +35,7 @@ import java.util.Objects;
 import java.util.Set;
 import java.util.regex.Pattern;
 
+import static java.util.Arrays.asList;
 import static java.util.Collections.singletonList;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
@@ -794,6 +798,79 @@ public class SecurityMemberAccessTest {
         assertTrue("package java.lang. is accessible!", actual);
     }
 
+    @Test
+    public void classInclusion() throws Exception {
+
+        sma.useEnforceAllowlistEnabled(true);
+
+        TestBean2 bean = new TestBean2();
+        Method method = TestBean2.class.getMethod("getData");
+
+        assertFalse(sma.checkAllowlist(bean, method));
+
+        sma.useAllowlistClasses(new HashSet<>(singletonList(TestBean2.class.getName())));
+
+        assertTrue(sma.checkAllowlist(bean, method));
+    }
+
+    @Test
+    public void packageInclusion() throws Exception {
+        sma.useEnforceAllowlistEnabled(true);
+
+        TestBean2 bean = new TestBean2();
+        Method method = TestBean2.class.getMethod("getData");
+
+        assertFalse(sma.checkAllowlist(bean, method));
+
+        sma.useAllowlistPackageNames(new HashSet<>(singletonList(TestBean2.class.getPackage().getName())));
+
+        assertTrue(sma.checkAllowlist(bean, method));
+    }
+
+    @Test
+    public void classInclusion_subclass() throws Exception {
+        sma.useEnforceAllowlistEnabled(true);
+        sma.useAllowlistClasses(new HashSet<>(singletonList(TestBean2.class.getName())));
+
+        TestBean2 bean = new TestBean2();
+        Method method = TestBean2.class.getMethod("getName");
+
+        assertFalse(sma.checkAllowlist(bean, method));
+    }
+
+    @Test
+    public void classInclusion_subclass_both() throws Exception {
+        sma.useEnforceAllowlistEnabled(true);
+        sma.useAllowlistClasses(new HashSet<>(asList(TestBean.class.getName(), TestBean2.class.getName())));
+
+        TestBean2 bean = new TestBean2();
+        Method method = TestBean2.class.getMethod("getName");
+
+        assertTrue(sma.checkAllowlist(bean, method));
+    }
+
+    @Test
+    public void packageInclusion_subclass() throws Exception {
+        sma.useEnforceAllowlistEnabled(true);
+        sma.useAllowlistPackageNames(new HashSet<>(singletonList(TestBean2.class.getPackage().getName())));
+
+        TestBean2 bean = new TestBean2();
+        Method method = TestBean2.class.getMethod("getName");
+
+        assertFalse(sma.checkAllowlist(bean, method));
+    }
+
+    @Test
+    public void packageInclusion_subclass_both() throws Exception {
+        sma.useEnforceAllowlistEnabled(true);
+        sma.useAllowlistPackageNames(new HashSet<>(asList(TestBean.class.getPackage().getName(), TestBean2.class.getPackage().getName())));
+
+        TestBean2 bean = new TestBean2();
+        Method method = TestBean2.class.getMethod("getName");
+
+        assertTrue(sma.checkAllowlist(bean, method));
+    }
+
     private static String formGetterName(String propertyName) {
         return "get" + propertyName.substring(0, 1).toUpperCase() + propertyName.substring(1);
     }

Reply via email to