diff --git a/cas-client-core/src/main/java/org/jasig/cas/client/util/ErrorRedirectFilter.java b/cas-client-core/src/main/java/org/jasig/cas/client/util/ErrorRedirectFilter.java index 74496c0..dd6f4a2 100644 --- a/cas-client-core/src/main/java/org/jasig/cas/client/util/ErrorRedirectFilter.java +++ b/cas-client-core/src/main/java/org/jasig/cas/client/util/ErrorRedirectFilter.java @@ -24,22 +24,22 @@ import java.util.Enumeration; import java.util.List; import javax.servlet.*; import javax.servlet.http.HttpServletResponse; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Filters that redirects to the supplied url based on an exception. Exceptions and the urls are configured via * init filter name/param values. - *
+ *
* If there is an exact match the filter uses that value. If there's a non-exact match (i.e. inheritance), then the filter * uses the last value that matched. - *+ *
* If there is no match it will redirect to a default error page. The default exception is configured via the "defaultErrorRedirectPage" property. - * + * * @author Scott Battaglia * @version $Revision$ $Date$ * @since 3.1.4 - * */ public final class ErrorRedirectFilter implements Filter { @@ -58,8 +58,8 @@ public final class ErrorRedirectFilter implements Filter { final HttpServletResponse httpResponse = (HttpServletResponse) response; try { filterChain.doFilter(request, response); - } catch (final ServletException e) { - final Throwable t = e.getCause(); + } catch (final Throwable e) { + final Throwable t = extractErrorToCompare(e); ErrorHolder currentMatch = null; for (final ErrorHolder errorHolder : this.errors) { if (errorHolder.exactMatch(t)) { @@ -78,6 +78,22 @@ public final class ErrorRedirectFilter implements Filter { } } + /** + * Determine which error to use for comparison. If there is an {@link Throwable#getCause()} then that will be used. Otherwise, the original throwable is used. + * + * @param throwable the throwable to look for a root cause. + * @return the throwable to use for comparison. MUST NOT BE NULL. + */ + private Throwable extractErrorToCompare(final Throwable throwable) { + final Throwable cause = throwable.getCause(); + + if (cause != null) { + return cause; + } + + return throwable; + } + public void init(final FilterConfig filterConfig) throws ServletException { this.defaultErrorRedirectPage = filterConfig.getInitParameter("defaultErrorRedirectPage"); diff --git a/cas-client-core/src/test/java/org/jasig/cas/client/util/ErrorRedirectFilterTests.java b/cas-client-core/src/test/java/org/jasig/cas/client/util/ErrorRedirectFilterTests.java new file mode 100644 index 0000000..b79daf4 --- /dev/null +++ b/cas-client-core/src/test/java/org/jasig/cas/client/util/ErrorRedirectFilterTests.java @@ -0,0 +1,48 @@ +package org.jasig.cas.client.util; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.mock.web.MockFilterChain; +import org.springframework.mock.web.MockFilterConfig; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; + +import javax.servlet.FilterChain; + +import static org.junit.Assert.*; + +public final class ErrorRedirectFilterTests { + + private static final String REDIRECT_URL = "/ise.html"; + + private ErrorRedirectFilter errorRedirectFilter; + + private FilterChain filterChain; + + + @Before + public void setUp() throws Exception { + this.errorRedirectFilter = new ErrorRedirectFilter(); + + final MockFilterConfig filterConfig = new MockFilterConfig(); + filterConfig.addInitParameter(IllegalStateException.class.getName(), REDIRECT_URL); + this.errorRedirectFilter.init(filterConfig); + this.filterChain = new MockFilterChain(); + } + + + @Test + public void noRootCause() throws Exception { + final MockHttpServletRequest request = new MockHttpServletRequest(); + final MockHttpServletResponse response = new MockHttpServletResponse(); + + // this should be okay as the mock filter chain allows one call + this.errorRedirectFilter.doFilter(request, response, this.filterChain); + + // this will fail as the mock filter chain will throw IllegalStateException + this.errorRedirectFilter.doFilter(request, response, this.filterChain); + + assertEquals(REDIRECT_URL, response.getRedirectedUrl()); + + } +}