This is an automated email from the ASF dual-hosted git repository.

nbonte pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/atlas.git

commit ff2d63e38dcca51a60816d04acf611e7f7c1a260
Author: nixonrodrigues <ni...@apache.org>
AuthorDate: Thu May 20 21:57:00 2021 -0700

    ATLAS-4064: Atlas HEADER validation
    
    Signed-off-by: Nikhil Bonte <nbo...@apache.org>
    (cherry picked from commit 4691650dfe26f13884483bba6025cb66f4f818da)
---
 dashboardv2/public/js/utils/CommonViewFunction.js  | 30 +++++++---------
 dashboardv3/public/js/utils/CommonViewFunction.js  | 29 +++++++--------
 .../web/filters/AtlasCSRFPreventionFilter.java     | 41 ++++++++++++++++------
 .../apache/atlas/web/resources/AdminResource.java  | 18 +++++++---
 .../web/filters/AtlasCSRFPreventionFilterTest.java | 31 ++++++++++++++++
 5 files changed, 99 insertions(+), 50 deletions(-)

diff --git a/dashboardv2/public/js/utils/CommonViewFunction.js b/dashboardv2/public/js/utils/CommonViewFunction.js
index 80db527..bb3fa3f 100644
--- a/dashboardv2/public/js/utils/CommonViewFunction.js
+++ b/dashboardv2/public/js/utils/CommonViewFunction.js
@@ -793,7 +793,6 @@ define(['require', 'utils/Utils', 'modules/Modal', 'utils/Messages', 'utils/Enum
                 }));
             }
         }
-
     }
     CommonViewFunction.removeCategoryTermAssociation = function(options) {
         if (options) {
@@ -864,13 +863,10 @@ define(['require', 'utils/Utils', 'modules/Modal', 'utils/Messages', 'utils/Enum
         }
     }
     CommonViewFunction.addRestCsrfCustomHeader = function(xhr, settings) {
-        if (settings.url == null) {
-            return;
-        }
-        var method = settings.type;
-        if (CommonViewFunction.restCsrfCustomHeader != null && 
!CommonViewFunction.restCsrfMethodsToIgnore[method]) {
-            // The value of the header is unimportant.  Only its presence 
matters.
-            xhr.setRequestHeader(CommonViewFunction.restCsrfCustomHeader, 
'""');
+        if (null != settings.url) {
+            var method = settings.type;
+            var csrfToken = CommonViewFunction.restCsrfValue;
+            null == CommonViewFunction.restCsrfCustomHeader || 
CommonViewFunction.restCsrfMethodsToIgnore[method] || 
xhr.setRequestHeader(CommonViewFunction.restCsrfCustomHeader, csrfToken);
         }
     }
     CommonViewFunction.restCsrfCustomHeader = null;
@@ -900,16 +896,14 @@ define(['require', 'utils/Utils', 'modules/Modal', 'utils/Messages', 'utils/Enum
                             var str = "" + response['atlas.rest-csrf.enabled'];
                             csrfEnabled = (str.toLowerCase() == 'true');
                         }
-                        if (response['atlas.rest-csrf.custom-header']) {
-                            header = 
response['atlas.rest-csrf.custom-header'].trim();
-                        }
-                        if (response['atlas.rest-csrf.methods-to-ignore']) {
-                            methods = 
getTrimmedStringArrayValue(response['atlas.rest-csrf.methods-to-ignore']);
-                        }
-                        if (csrfEnabled) {
-                            CommonViewFunction.restCsrfCustomHeader = header;
-                            CommonViewFunction.restCsrfMethodsToIgnore = {};
-                            methods.map(function(method) { 
CommonViewFunction.restCsrfMethodsToIgnore[method] = true; });
+                        if (response["atlas.rest-csrf.custom-header"] && 
(header = response["atlas.rest-csrf.custom-header"].trim()),
+                            response["atlas.rest-csrf.methods-to-ignore"] && 
(methods = 
getTrimmedStringArrayValue(response["atlas.rest-csrf.methods-to-ignore"])),
+                            csrfEnabled) {
+                            CommonViewFunction.restCsrfCustomHeader = header, 
CommonViewFunction.restCsrfMethodsToIgnore = {},
+                                CommonViewFunction.restCsrfValue = 
response["_csrfToken"] || '""',
+                                methods.map(function(method) {
+                                    
CommonViewFunction.restCsrfMethodsToIgnore[method] = !0;
+                                });
                             var statusCodeErrorFn = function(error) {
                                 Utils.defaultErrorHandler(null, error)
                             }
diff --git a/dashboardv3/public/js/utils/CommonViewFunction.js b/dashboardv3/public/js/utils/CommonViewFunction.js
index 14a8b74..34afa2d 100644
--- a/dashboardv3/public/js/utils/CommonViewFunction.js
+++ b/dashboardv3/public/js/utils/CommonViewFunction.js
@@ -884,13 +884,10 @@ define(['require', 'utils/Utils', 'modules/Modal', 'utils/Messages', 'utils/Enum
         }
     }
     CommonViewFunction.addRestCsrfCustomHeader = function(xhr, settings) {
-        if (settings.url == null) {
-            return;
-        }
-        var method = settings.type;
-        if (CommonViewFunction.restCsrfCustomHeader != null && 
!CommonViewFunction.restCsrfMethodsToIgnore[method]) {
-            // The value of the header is unimportant.  Only its presence 
matters.
-            xhr.setRequestHeader(CommonViewFunction.restCsrfCustomHeader, 
'""');
+        if (null != settings.url) {
+            var method = settings.type;
+            var csrfToken = CommonViewFunction.restCsrfValue;
+            null == CommonViewFunction.restCsrfCustomHeader || 
CommonViewFunction.restCsrfMethodsToIgnore[method] || 
xhr.setRequestHeader(CommonViewFunction.restCsrfCustomHeader, csrfToken);
         }
     }
     CommonViewFunction.restCsrfCustomHeader = null;
@@ -920,16 +917,14 @@ define(['require', 'utils/Utils', 'modules/Modal', 'utils/Messages', 'utils/Enum
                             var str = "" + response['atlas.rest-csrf.enabled'];
                             csrfEnabled = (str.toLowerCase() == 'true');
                         }
-                        if (response['atlas.rest-csrf.custom-header']) {
-                            header = 
response['atlas.rest-csrf.custom-header'].trim();
-                        }
-                        if (response['atlas.rest-csrf.methods-to-ignore']) {
-                            methods = 
getTrimmedStringArrayValue(response['atlas.rest-csrf.methods-to-ignore']);
-                        }
-                        if (csrfEnabled) {
-                            CommonViewFunction.restCsrfCustomHeader = header;
-                            CommonViewFunction.restCsrfMethodsToIgnore = {};
-                            methods.map(function(method) { 
CommonViewFunction.restCsrfMethodsToIgnore[method] = true; });
+                        if (response["atlas.rest-csrf.custom-header"] && 
(header = response["atlas.rest-csrf.custom-header"].trim()),
+                            response["atlas.rest-csrf.methods-to-ignore"] && 
(methods = 
getTrimmedStringArrayValue(response["atlas.rest-csrf.methods-to-ignore"])),
+                            csrfEnabled) {
+                            CommonViewFunction.restCsrfCustomHeader = header, 
CommonViewFunction.restCsrfMethodsToIgnore = {},
+                                CommonViewFunction.restCsrfValue = 
response["_csrfToken"] || '""',
+                                methods.map(function(method) {
+                                    
CommonViewFunction.restCsrfMethodsToIgnore[method] = !0;
+                                });
                             var statusCodeErrorFn = function(error) {
                                 Utils.defaultErrorHandler(null, error)
                             }
diff --git a/webapp/src/main/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilter.java b/webapp/src/main/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilter.java
index df3fce6..429ff1c 100644
--- a/webapp/src/main/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilter.java
+++ b/webapp/src/main/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilter.java
@@ -21,6 +21,7 @@ package org.apache.atlas.web.filters;
 import org.apache.atlas.ApplicationProperties;
 import org.apache.atlas.AtlasException;
 import org.apache.commons.configuration.Configuration;
+import org.apache.commons.lang.StringUtils;
 import org.json.simple.JSONObject;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -34,6 +35,7 @@ import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpSession;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.HashSet;
@@ -63,11 +65,13 @@ public class AtlasCSRFPreventionFilter implements Filter {
        public static final String CUSTOM_HEADER_PARAM = 
"atlas.rest-csrf.custom-header";
        public static final String HEADER_DEFAULT = "X-XSRF-HEADER";
        public static final String HEADER_USER_AGENT = "User-Agent";
+       public static final String CSRF_TOKEN = "_csrfToken";
+
 
        private String  headerName = HEADER_DEFAULT;
        private Set<String> methodsToIgnore = null;
        private Set<Pattern> browserUserAgents;
-       
+
        public AtlasCSRFPreventionFilter() {
                try {
                        if (isCSRF_ENABLED){
@@ -167,19 +171,30 @@ public class AtlasCSRFPreventionFilter implements Filter {
                 *             if there is an I/O error
                 */
                void sendError(int code, String message) throws IOException;
-       }       
-         
-       public void handleHttpInteraction(HttpInteraction httpInteraction)
-                       throws IOException, ServletException {
-               if (!isBrowser(httpInteraction.getHeader(HEADER_USER_AGENT))
-                               || 
methodsToIgnore.contains(httpInteraction.getMethod())
-                               || httpInteraction.getHeader(headerName) != 
null) {
+       }
+
+       public void handleHttpInteraction(HttpInteraction httpInteraction) 
throws IOException, ServletException {
+               HttpSession session   = ((ServletFilterHttpInteraction) 
httpInteraction).getSession();
+               String      csrfToken = StringUtils.EMPTY;
+
+               if (session != null) {
+                       csrfToken = (String) session.getAttribute(CSRF_TOKEN);
+               } else {
+                       if (LOG.isDebugEnabled()) {
+                               LOG.debug("Session is null");
+                       }
+               }
+
+               String clientCsrfToken = httpInteraction.getHeader(headerName);
+
+               if (!isBrowser(httpInteraction.getHeader(HEADER_USER_AGENT)) || 
methodsToIgnore.contains(httpInteraction.getMethod())
+                               || (clientCsrfToken != null && 
clientCsrfToken.equals(csrfToken))) {
                        httpInteraction.proceed();
-               }else {
-                       
httpInteraction.sendError(HttpServletResponse.SC_BAD_REQUEST,"Missing Required 
Header for CSRF Vulnerability Protection");
+               } else {
+                       
httpInteraction.sendError(HttpServletResponse.SC_BAD_REQUEST,"Missing header or 
invalid Header value for CSRF Vulnerability Protection");
                }
        }
-       
+
        public void doFilter(ServletRequest request, ServletResponse response, 
FilterChain chain) throws IOException, ServletException {
         final HttpServletRequest httpRequest = (HttpServletRequest) request;
         final HttpServletResponse httpResponse = (HttpServletResponse) 
response;
@@ -235,6 +250,10 @@ public class AtlasCSRFPreventionFilter implements Filter {
                        chain.doFilter(httpRequest, httpResponse);
                }
 
+               public HttpSession getSession() {
+                       return httpRequest.getSession();
+               }
+
                @Override
                public void sendError(int code, String message) throws 
IOException {
                        JSONObject json = new JSONObject();
diff --git a/webapp/src/main/java/org/apache/atlas/web/resources/AdminResource.java b/webapp/src/main/java/org/apache/atlas/web/resources/AdminResource.java
index d124cd2..46d42ba 100755
--- a/webapp/src/main/java/org/apache/atlas/web/resources/AdminResource.java
+++ b/webapp/src/main/java/org/apache/atlas/web/resources/AdminResource.java
@@ -74,6 +74,7 @@ import org.apache.commons.collections.CollectionUtils;
 import org.apache.commons.configuration.Configuration;
 import org.apache.commons.configuration.ConfigurationException;
 import org.apache.commons.configuration.PropertiesConfiguration;
+import org.apache.commons.lang.RandomStringUtils;
 import org.apache.commons.lang.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -102,6 +103,7 @@ import javax.ws.rs.core.MediaType;
 import javax.ws.rs.core.Response;
 import java.io.IOException;
 import java.io.InputStream;
+import java.security.SecureRandom;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -114,6 +116,8 @@ import java.util.TimeZone;
 import java.util.concurrent.locks.ReentrantLock;
 import java.util.stream.Collectors;
 
+import static 
org.apache.atlas.web.filters.AtlasCSRFPreventionFilter.CSRF_TOKEN;
+
 
 /**
  * Jersey Resource for admin operations.
@@ -326,7 +330,7 @@ public class AdminResource {
     @GET
     @Path("session")
     @Produces(Servlets.JSON_MEDIA_TYPE)
-    public Response getUserProfile() {
+    public Response getUserProfile(@Context HttpServletRequest request) {
         if (LOG.isDebugEnabled()) {
             LOG.debug("==> AdminResource.getUserProfile()");
         }
@@ -364,9 +368,15 @@ public class AdminResource {
         responseData.put("timezones", TIMEZONE_LIST);
         responseData.put(UI_DATE_TIMEZONE_FORMAT_ENABLED, 
isTimezoneFormatEnabled);
         responseData.put(UI_DATE_FORMAT, uiDateFormat);
-        
responseData.put(AtlasConfiguration.DEBUG_METRICS_ENABLED.getPropertyName(), 
isDebugMetricsEnabled);
-        
responseData.put(AtlasConfiguration.TASKS_USE_ENABLED.getPropertyName(), 
isTasksEnabled);
-        
+
+        String salt = (String) request.getSession().getAttribute(CSRF_TOKEN);
+        if (StringUtils.isEmpty(salt)) {
+            salt = RandomStringUtils.random(20, 0, 0, true, true, null, new 
SecureRandom());
+            request.getSession().setAttribute(CSRF_TOKEN, salt);
+        }
+
+        responseData.put(CSRF_TOKEN, salt);
+
         response = Response.ok(AtlasJson.toV1Json(responseData)).build();
 
         if (LOG.isDebugEnabled()) {
diff --git a/webapp/src/test/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilterTest.java b/webapp/src/test/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilterTest.java
index 954364b..841cfaf 100644
--- a/webapp/src/test/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilterTest.java
+++ b/webapp/src/test/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilterTest.java
@@ -23,10 +23,13 @@ import javax.servlet.FilterChain;
 import javax.servlet.ServletException;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpSession;
 import java.io.IOException;
 import java.io.PrintWriter;
 
+import static 
org.apache.atlas.web.filters.AtlasCSRFPreventionFilter.CSRF_TOKEN;
 import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.verify;
 
 public class AtlasCSRFPreventionFilterTest {
@@ -61,9 +64,15 @@ public class AtlasCSRFPreventionFilterTest {
                HttpServletRequest mockReq = 
Mockito.mock(HttpServletRequest.class);
                
Mockito.when(mockReq.getHeader(AtlasCSRFPreventionFilter.HEADER_DEFAULT)).thenReturn("valueUnimportant");
                
Mockito.when(mockReq.getHeader(AtlasCSRFPreventionFilter.HEADER_USER_AGENT)).thenReturn(userAgent);
+               Mockito.when(mockReq.getMethod()).thenReturn("POST");
+
+               HttpSession session = Mockito.mock(HttpSession.class);
+               
Mockito.when(session.getAttribute(CSRF_TOKEN)).thenReturn("valueUnimportant");
+               Mockito.when(mockReq.getSession()).thenReturn(session);
 
                // Objects to verify interactions based on request
                HttpServletResponse mockRes = 
Mockito.mock(HttpServletResponse.class);
+
                FilterChain mockChain = Mockito.mock(FilterChain.class);
 
                // Object under test
@@ -74,6 +83,28 @@ public class AtlasCSRFPreventionFilterTest {
        }
 
        @Test
+       public void testHeaderPresentDefaultConfig_badRequest() throws 
ServletException, IOException {
+               // CSRF HAS been sent
+               HttpServletRequest mockReq = 
Mockito.mock(HttpServletRequest.class);
+               
Mockito.when(mockReq.getHeader(AtlasCSRFPreventionFilter.HEADER_DEFAULT)).thenReturn("valueUnimportant");
+               
Mockito.when(mockReq.getHeader(AtlasCSRFPreventionFilter.HEADER_USER_AGENT)).thenReturn(userAgent);
+               Mockito.when(mockReq.getMethod()).thenReturn("POST");
+
+               // Objects to verify interactions based on request
+               HttpServletResponse mockRes = 
Mockito.mock(HttpServletResponse.class);
+               PrintWriter mockWriter = Mockito.mock(PrintWriter.class);
+               Mockito.when(mockRes.getWriter()).thenReturn(mockWriter);
+
+               FilterChain mockChain = Mockito.mock(FilterChain.class);
+
+               // Object under test
+               AtlasCSRFPreventionFilter filter = new 
AtlasCSRFPreventionFilter();
+               filter.doFilter(mockReq, mockRes, mockChain);
+
+               Mockito.verify(mockChain, never()).doFilter(mockReq, mockRes);
+       }
+
+       @Test
        public void testHeaderPresentCustomHeaderConfig_goodRequest() throws 
ServletException, IOException {
                // CSRF HAS been sent
                HttpServletRequest mockReq = 
Mockito.mock(HttpServletRequest.class);

Reply via email to