blob: be20d586b25daa2e8613d58b57876bb721e0800f [file] [log] [blame]
//
// ========================================================================
// Copyright (c) 1995-2014 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.servlets;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpSessionActivationListener;
import javax.servlet.http.HttpSessionBindingEvent;
import javax.servlet.http.HttpSessionBindingListener;
import javax.servlet.http.HttpSessionEvent;
import org.eclipse.jetty.continuation.Continuation;
import org.eclipse.jetty.continuation.ContinuationListener;
import org.eclipse.jetty.continuation.ContinuationSupport;
import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.util.thread.Timeout;
/**
 * Denial of Service filter.
 * <p>
 * This filter is useful for limiting exposure to abuse from request flooding,
 * whether malicious, or as a result of a misconfigured client.
 * <p>
 * The filter keeps track of the number of requests from a connection per
 * second. If a limit is exceeded, the request is either rejected, delayed, or
 * throttled.
 * <p>
 * When a request is throttled, it is placed in a priority queue. Priority is
 * given first to authenticated users and users with an HttpSession, then
 * connections which can be identified by their IP addresses. Connections with
 * no way to identify them are given lowest priority.
 * <p>
 * The {@link #extractUserId(ServletRequest request)} function should be
 * implemented, in order to uniquely identify authenticated users.
 * <p>
 * The following init parameters control the behavior of the filter:
 * <dl>
 * <dt>maxRequestsPerSec</dt>
 * <dd>the maximum number of requests from a connection per second. Requests in
 * excess of this are first delayed, then throttled. Defaults to 25.</dd>
 * <dt>delayMs</dt>
 * <dd>the delay given to all requests over the rate limit, before they are
 * considered at all. A negative value (conventionally -1) means just reject the
 * request, 0 means no delay, otherwise it is the delay. Defaults to 100.</dd>
 * <dt>maxWaitMs</dt>
 * <dd>how long to wait, blocking, for the throttle semaphore. Defaults to 50.</dd>
 * <dt>throttledRequests</dt>
 * <dd>the number of requests over the rate limit able to be considered at once.
 * Defaults to 5.</dd>
 * <dt>throttleMs</dt>
 * <dd>how long to wait asynchronously for the throttle semaphore. Defaults to 30000.</dd>
 * <dt>maxRequestMs</dt>
 * <dd>how long to allow a request to run. Defaults to 30000.</dd>
 * <dt>maxIdleTrackerMs</dt>
 * <dd>how long to keep track of request rates for a connection, before deciding
 * that the user has gone away, and discarding it. Defaults to 30000.</dd>
 * <dt>insertHeaders</dt>
 * <dd>if true, insert the DoSFilter headers into the response. Defaults to true.</dd>
 * <dt>trackSessions</dt>
 * <dd>if true, usage rate is tracked by session if a session exists. Defaults to true.</dd>
 * <dt>remotePort</dt>
 * <dd>if true and session tracking is not used, then rate is tracked by IP+port
 * (effectively connection). Defaults to false.</dd>
 * <dt>ipWhitelist</dt>
 * <dd>a comma-separated list of IP addresses, or CIDR ranges such as
 * 192.168.0.0/16, that will not be rate limited</dd>
 * <dt>enabled</dt>
 * <dd>if false, every request passes straight through the filter. Defaults to true.</dd>
 * <dt>managedAttr</dt>
 * <dd>if set to true, then this filter is set as a {@link ServletContext} attribute with the
 * filter name as the attribute name. This allows context external mechanism (eg JMX via
 * {@link ContextHandler#MANAGED_ATTRIBUTES}) to manage the configuration of the filter.</dd>
 * </dl>
 */
public class DoSFilter implements Filter
{
    private static final Logger LOG = Log.getLogger(DoSFilter.class);

    private static final String IPv4_GROUP = "(\\d{1,3})";
    private static final Pattern IPv4_PATTERN = Pattern.compile(IPv4_GROUP+"\\."+IPv4_GROUP+"\\."+IPv4_GROUP+"\\."+IPv4_GROUP);
    private static final String IPv6_GROUP = "(\\p{XDigit}{1,4})";
    private static final Pattern IPv6_PATTERN = Pattern.compile(IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP);
    private static final Pattern CIDR_PATTERN = Pattern.compile("([^/]+)/(\\d+)");

    // Request (and session) attribute names used by this filter
    private static final String __TRACKER = "DoSFilter.Tracker";
    private static final String __THROTTLED = "DoSFilter.Throttled";

    private static final int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
    private static final int __DEFAULT_DELAY_MS = 100;
    private static final int __DEFAULT_THROTTLE = 5;
    private static final int __DEFAULT_MAX_WAIT_MS = 50;
    private static final long __DEFAULT_THROTTLE_MS = 30000L;
    private static final long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM = 30000L;
    private static final long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM = 30000L;

    static final String MANAGED_ATTR_INIT_PARAM = "managedAttr";
    static final String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
    static final String DELAY_MS_INIT_PARAM = "delayMs";
    static final String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
    static final String MAX_WAIT_INIT_PARAM = "maxWaitMs";
    static final String THROTTLE_MS_INIT_PARAM = "throttleMs";
    static final String MAX_REQUEST_MS_INIT_PARAM = "maxRequestMs";
    static final String MAX_IDLE_TRACKER_MS_INIT_PARAM = "maxIdleTrackerMs";
    static final String INSERT_HEADERS_INIT_PARAM = "insertHeaders";
    static final String TRACK_SESSIONS_INIT_PARAM = "trackSessions";
    static final String REMOTE_PORT_INIT_PARAM = "remotePort";
    static final String IP_WHITELIST_INIT_PARAM = "ipWhitelist";
    static final String ENABLED_INIT_PARAM = "enabled";

    // Tracker types double as throttle-queue priorities: a higher value is resumed first.
    // Authenticated users and session holders deliberately share the top priority.
    private static final int USER_AUTH = 2;
    private static final int USER_SESSION = 2;
    private static final int USER_IP = 1;
    private static final int USER_UNKNOWN = 0;

    private ServletContext _context;
    private volatile long _delayMs;
    private volatile long _throttleMs;
    private volatile long _maxWaitMs;
    private volatile long _maxRequestMs;
    private volatile long _maxIdleTrackerMs;
    private volatile boolean _insertHeaders;
    private volatile boolean _trackSessions;
    private volatile boolean _remotePort;
    private volatile boolean _enabled;
    // Created once and then resized in place, so permits held by in-flight requests are never lost
    private volatile AdjustableSemaphore _passes;
    private volatile int _throttledRequests;
    private volatile int _maxRequestsPerSec;
    // One queue of suspended (throttled) requests per priority, indexed by priority
    private Queue<Continuation>[] _queue;
    private ContinuationListener[] _listeners;
    private final ConcurrentHashMap<String, RateTracker> _rateTrackers = new ConcurrentHashMap<String, RateTracker>();
    private final List<String> _whitelist = new CopyOnWriteArrayList<String>();
    private final Timeout _requestTimeoutQ = new Timeout();
    private final Timeout _trackerTimeoutQ = new Timeout();
    private Thread _timerThread;
    private volatile boolean _running;

    /**
     * Reads the init parameters, creates the per-priority throttle queues and
     * starts the background thread that ticks the request and tracker timeout queues.
     *
     * @param filterConfig the filter configuration
     */
    public void init(FilterConfig filterConfig)
    {
        _context = filterConfig.getServletContext();

        int queueCount = getMaxPriority() + 1;
        @SuppressWarnings("unchecked")
        Queue<Continuation>[] queues = new Queue[queueCount];
        _queue = queues;
        _listeners = new ContinuationListener[queueCount];
        for (int p = 0; p < queueCount; p++)
        {
            _queue[p] = new ConcurrentLinkedQueue<Continuation>();
            final int priority = p;
            _listeners[p] = new ContinuationListener()
            {
                public void onComplete(Continuation continuation)
                {
                }

                public void onTimeout(Continuation continuation)
                {
                    // A throttled request that timed out must not be picked for resumption later
                    _queue[priority].remove(continuation);
                }
            };
        }

        _rateTrackers.clear();

        setMaxRequestsPerSec(getIntInitParameter(filterConfig, MAX_REQUESTS_PER_S_INIT_PARAM, __DEFAULT_MAX_REQUESTS_PER_SEC));
        setDelayMs(getLongInitParameter(filterConfig, DELAY_MS_INIT_PARAM, __DEFAULT_DELAY_MS));
        setThrottledRequests(getIntInitParameter(filterConfig, THROTTLED_REQUESTS_INIT_PARAM, __DEFAULT_THROTTLE));
        setMaxWaitMs(getLongInitParameter(filterConfig, MAX_WAIT_INIT_PARAM, __DEFAULT_MAX_WAIT_MS));
        setThrottleMs(getLongInitParameter(filterConfig, THROTTLE_MS_INIT_PARAM, __DEFAULT_THROTTLE_MS));
        setMaxRequestMs(getLongInitParameter(filterConfig, MAX_REQUEST_MS_INIT_PARAM, __DEFAULT_MAX_REQUEST_MS_INIT_PARAM));
        setMaxIdleTrackerMs(getLongInitParameter(filterConfig, MAX_IDLE_TRACKER_MS_INIT_PARAM, __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM));

        String whiteList = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
        setWhitelist(whiteList == null ? "" : whiteList);

        setInsertHeaders(getBooleanInitParameter(filterConfig, INSERT_HEADERS_INIT_PARAM, true));
        setTrackSessions(getBooleanInitParameter(filterConfig, TRACK_SESSIONS_INIT_PARAM, true));
        setRemotePort(getBooleanInitParameter(filterConfig, REMOTE_PORT_INIT_PARAM, false));
        setEnabled(getBooleanInitParameter(filterConfig, ENABLED_INIT_PARAM, true));

        // Durations were applied by setMaxRequestMs/setMaxIdleTrackerMs; only the clocks need priming
        _requestTimeoutQ.setNow();
        _trackerTimeoutQ.setNow();

        _running = true;
        _timerThread = new Thread("DoSFilter timer")
        {
            @Override
            public void run()
            {
                try
                {
                    while (_running)
                    {
                        long now = _requestTimeoutQ.setNow();
                        _requestTimeoutQ.tick();
                        _trackerTimeoutQ.setNow(now);
                        _trackerTimeoutQ.tick();
                        try
                        {
                            Thread.sleep(100);
                        }
                        catch (InterruptedException e)
                        {
                            // destroy() clears _running and then interrupts this thread
                            LOG.ignore(e);
                        }
                    }
                }
                finally
                {
                    LOG.debug("DoSFilter timer exited");
                }
            }
        };
        // Must not keep the JVM alive if destroy() is never called
        _timerThread.setDaemon(true);
        _timerThread.start();

        if (_context != null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
            _context.setAttribute(filterConfig.getFilterName(), this);
    }

    /**
     * Parses an integer init parameter.
     *
     * @return the parsed value, or {@code defaultValue} if the parameter is absent
     * @throws NumberFormatException if the parameter is present but not an integer
     */
    private static int getIntInitParameter(FilterConfig config, String name, int defaultValue)
    {
        String value = config.getInitParameter(name);
        return value == null ? defaultValue : Integer.parseInt(value);
    }

    /**
     * Parses a long init parameter.
     *
     * @return the parsed value, or {@code defaultValue} if the parameter is absent
     * @throws NumberFormatException if the parameter is present but not a long
     */
    private static long getLongInitParameter(FilterConfig config, String name, long defaultValue)
    {
        String value = config.getInitParameter(name);
        return value == null ? defaultValue : Long.parseLong(value);
    }

    /**
     * Parses a boolean init parameter with {@link Boolean#parseBoolean(String)}.
     *
     * @return the parsed value, or {@code defaultValue} if the parameter is absent
     */
    private static boolean getBooleanInitParameter(FilterConfig config, String name, boolean defaultValue)
    {
        String value = config.getInitParameter(name);
        return value == null ? defaultValue : Boolean.parseBoolean(value);
    }

    public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException
    {
        doFilter((HttpServletRequest)request, (HttpServletResponse)response, filterChain);
    }

    /**
     * Applies rate limiting to a request. Requests under the rate limit pass
     * straight through. Requests over it are rejected, delayed (suspended and
     * redispatched) or throttled, depending on {@link #getDelayMs()}.
     * <p>
     * This method may run several times for one request, because suspended
     * requests are redispatched. The {@code DoSFilter.Tracker} request
     * attribute tells a redispatch apart from the first pass.
     */
    protected void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws IOException, ServletException
    {
        if (!isEnabled())
        {
            filterChain.doFilter(request, response);
            return;
        }

        final long now = _requestTimeoutQ.getNow();

        // Look for the rate tracker for this request
        RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
        if (tracker == null)
        {
            // First time we see this request: record one hit on its rate tracker
            tracker = getRateTracker(request);

            // Pass it through if we are not currently over the rate limit
            if (!tracker.isRateExceeded(now))
            {
                doFilterChain(filterChain, request, response);
                return;
            }

            // We are over the limit, so either reject it, delay it or throttle it.
            // Compare as long: casting to int would let large delays wrap around to -1 or 0.
            long delayMs = getDelayMs();
            boolean insertHeaders = isInsertHeaders();
            if (delayMs < 0)
            {
                // Reject this request
                LOG.warn("DOS ALERT: Request rejected ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal());
                if (insertHeaders)
                    response.addHeader("DoSFilter", "unavailable");
                response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
                return;
            }
            if (delayMs == 0)
            {
                // No delay: fall through to the throttle code below
                LOG.warn("DOS ALERT: Request throttled ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal());
                request.setAttribute(__TRACKER, tracker);
            }
            else
            {
                // Insert a delay before throttling the request; the redispatch will throttle
                LOG.warn("DOS ALERT: Request delayed="+delayMs+"ms ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal());
                if (insertHeaders)
                    response.addHeader("DoSFilter", "delayed");
                Continuation continuation = ContinuationSupport.getContinuation(request);
                request.setAttribute(__TRACKER, tracker);
                continuation.setTimeout(delayMs);
                continuation.suspend();
                return;
            }
        }

        // Throttle the request
        boolean accepted = false;
        try
        {
            // Check if we can afford to accept another request at this time
            accepted = _passes.tryAcquire(getMaxWaitMs(), TimeUnit.MILLISECONDS);
            if (!accepted)
            {
                // Not accepted: either suspend to wait, or, if we were woken up, insist, or fail
                final Continuation continuation = ContinuationSupport.getContinuation(request);
                boolean throttled = Boolean.TRUE.equals(request.getAttribute(__THROTTLED));
                long throttleMs = getThrottleMs();
                if (!throttled && throttleMs > 0)
                {
                    int priority = getPriority(request, tracker);
                    request.setAttribute(__THROTTLED, Boolean.TRUE);
                    if (isInsertHeaders())
                        response.addHeader("DoSFilter", "throttled");
                    continuation.setTimeout(throttleMs);
                    continuation.suspend();
                    continuation.addContinuationListener(_listeners[priority]);
                    _queue[priority].add(continuation);
                    return;
                }
                else if (Boolean.TRUE.equals(request.getAttribute("javax.servlet.resumed")))
                {
                    // We were resumed and somebody stole our pass, so we wait for the next one
                    _passes.acquire();
                    accepted = true;
                }
            }

            if (accepted)
            {
                // Accepted, either immediately or after throttling
                doFilterChain(filterChain, request, response);
            }
            else
            {
                // Fail the request (e.g. throttle wait timed out)
                if (isInsertHeaders())
                    response.addHeader("DoSFilter", "unavailable");
                response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
            }
        }
        catch (InterruptedException e)
        {
            // _context may be null (init() guards against it too)
            if (_context != null)
                _context.log("DoS", e);
            else
                LOG.warn(e);
            // The interrupt status is not restored here: this thread still has to write the error response
            response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
        }
        finally
        {
            if (accepted)
            {
                try
                {
                    resumeNextThrottled();
                }
                finally
                {
                    // Always hand the pass back, even if resuming another request failed
                    _passes.release();
                }
            }
        }
    }

    /**
     * Resumes the oldest still-suspended throttled request from the
     * highest-priority queue that has one. Entries that are no longer suspended
     * (for example, already timed out) are discarded. The search then continues
     * in the same queue, so lower priorities are not preferred by mistake.
     */
    private void resumeNextThrottled()
    {
        for (int p = _queue.length - 1; p >= 0; p--)
        {
            Continuation continuation;
            while ((continuation = _queue[p].poll()) != null)
            {
                if (continuation.isSuspended())
                {
                    continuation.resume();
                    return;
                }
            }
        }
    }

    /**
     * Invokes the rest of the filter chain. A request timeout is scheduled for the
     * duration of the call; if it expires, {@link #closeConnection} is invoked
     * from the timer thread.
     */
    protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response) throws IOException, ServletException
    {
        final Thread thread = Thread.currentThread();
        final Timeout.Task requestTimeout = new Timeout.Task()
        {
            public void expired()
            {
                closeConnection(request, response, thread);
            }
        };
        try
        {
            _requestTimeoutQ.schedule(requestTimeout);
            chain.doFilter(request, response);
        }
        finally
        {
            requestTimeout.cancel();
        }
    }

    /**
     * Takes drastic measures to return this response and stop this thread.
     * Due to the way the connection is interrupted, may return mixed up headers.
     *
     * @param request current request
     * @param response current response, which must be stopped
     * @param thread the handling thread
     */
    protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
    {
        // Take drastic measures to return this response and stop this thread
        if (!response.isCommitted())
        {
            response.setHeader("Connection", "close");
        }
        try
        {
            try
            {
                response.getWriter().close();
            }
            catch (IllegalStateException e)
            {
                // getOutputStream() was already used, so getWriter() is not allowed
                response.getOutputStream().close();
            }
        }
        catch (IOException e)
        {
            LOG.warn(e);
        }

        // Interrupt the handling thread
        thread.interrupt();
    }

    /**
     * Get priority for this request, based on user type
     *
     * @param request the current request
     * @param tracker the rate tracker for this request
     * @return the priority for this request; must be between 0 and {@link #getMaxPriority()}
     */
    protected int getPriority(HttpServletRequest request, RateTracker tracker)
    {
        if (extractUserId(request) != null)
            return USER_AUTH;
        if (tracker != null)
            return tracker.getType();
        return USER_UNKNOWN;
    }

    /**
     * @return the maximum priority that we can assign to a request
     */
    protected int getMaxPriority()
    {
        return USER_AUTH;
    }

    /**
     * Return a request rate tracker associated with this connection; keeps
     * track of this connection's request rate. If this is not the first request
     * from this connection, return the existing object with the stored stats.
     * If it is the first request, then create a new request tracker.
     * <p>
     * Assumes that each connection has an identifying characteristic, and goes
     * through them in order, taking the first that matches: user id (logged
     * in), session id, client IP address (optionally with port).
     * <p>
     * IP trackers are expired by the tracker timeout queue. Other trackers are
     * stored in the session (when there is one) and removed when they are unbound from it.
     *
     * @param request the current request
     * @return the request rate tracker for the current connection
     */
    public RateTracker getRateTracker(ServletRequest request)
    {
        HttpSession session = ((HttpServletRequest)request).getSession(false);

        String loadId = extractUserId(request);
        final int type;
        if (loadId != null)
        {
            type = USER_AUTH;
        }
        else if (_trackSessions && session != null && !session.isNew())
        {
            loadId = session.getId();
            type = USER_SESSION;
        }
        else
        {
            loadId = _remotePort ? createRemotePortId(request) : request.getRemoteAddr();
            type = USER_IP;
        }

        RateTracker tracker = _rateTrackers.get(loadId);
        if (tracker == null)
        {
            boolean allowed = checkWhitelist(_whitelist, request.getRemoteAddr());
            tracker = allowed ? new FixedRateTracker(loadId, type, _maxRequestsPerSec)
                    : new RateTracker(loadId, type, _maxRequestsPerSec);
            RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker);
            if (existing != null)
                tracker = existing;

            if (type == USER_IP)
            {
                // USER_IP expiration from _rateTrackers is handled by the _trackerTimeoutQ
                _trackerTimeoutQ.schedule(tracker);
            }
            else if (session != null)
            {
                // USER_SESSION expiration from _rateTrackers is handled by the HttpSessionBindingListener
                session.setAttribute(__TRACKER, tracker);
            }
        }
        return tracker;
    }

    /**
     * Builds the tracker key for IP+port tracking. The separator keeps
     * 10.0.0.1 port 23 distinct from 10.0.0.12 port 3. The port is numeric,
     * so the last ':' splits the key unambiguously even for IPv6 addresses.
     */
    private static String createRemotePortId(ServletRequest request)
    {
        return request.getRemoteAddr() + ":" + request.getRemotePort();
    }

    /**
     * @param whitelist plain addresses and/or CIDR ranges
     * @param candidate the address to check
     * @return true if {@code candidate} equals a plain entry or falls within a CIDR entry
     */
    protected boolean checkWhitelist(List<String> whitelist, String candidate)
    {
        for (String address : whitelist)
        {
            boolean match = address.contains("/") ? subnetMatch(address, candidate) : address.equals(candidate);
            if (match)
                return true;
        }
        return false;
    }

    /**
     * @param subnetAddress a CIDR range such as {@code 192.168.0.0/16}
     * @param address an IPv4 dotted-quad or full (uncompressed) IPv6 address
     * @return true if {@code address} lies within {@code subnetAddress}; false if
     * either is malformed, the prefix is too long for the address family, or the
     * families differ
     */
    protected boolean subnetMatch(String subnetAddress, String address)
    {
        Matcher cidrMatcher = CIDR_PATTERN.matcher(subnetAddress);
        if (!cidrMatcher.matches())
            return false;

        String subnet = cidrMatcher.group(1);
        int prefix;
        try
        {
            prefix = Integer.parseInt(cidrMatcher.group(2));
        }
        catch (NumberFormatException x)
        {
            LOG.info("Ignoring malformed CIDR address {}", subnetAddress);
            return false;
        }

        byte[] subnetBytes = addressToBytes(subnet);
        if (subnetBytes == null)
        {
            LOG.info("Ignoring malformed CIDR address {}", subnetAddress);
            return false;
        }
        byte[] addressBytes = addressToBytes(address);
        if (addressBytes == null)
        {
            LOG.info("Ignoring malformed remote address {}", address);
            return false;
        }

        // Comparing IPv4 with IPv6 ?
        int length = subnetBytes.length;
        if (length != addressBytes.length)
            return false;

        // e.g. "10.0.0.0/33" cannot be a valid IPv4 range
        if (prefix > length * 8)
        {
            LOG.info("Ignoring malformed CIDR address {}", subnetAddress);
            return false;
        }

        byte[] mask = prefixToBytes(prefix, length);
        for (int i = 0; i < length; ++i)
        {
            if ((subnetBytes[i] & mask[i]) != (addressBytes[i] & mask[i]))
                return false;
        }
        return true;
    }

    /**
     * Converts an IPv4 dotted-quad (4 bytes) or a full 8-group IPv6 address
     * (16 bytes) to its byte form.
     *
     * @return the address bytes, or null if the address is not in a supported form
     * or an IPv4 octet exceeds 255
     */
    private byte[] addressToBytes(String address)
    {
        Matcher ipv4Matcher = IPv4_PATTERN.matcher(address);
        if (ipv4Matcher.matches())
        {
            byte[] result = new byte[4];
            for (int i = 0; i < result.length; ++i)
            {
                int octet = Integer.parseInt(ipv4Matcher.group(i + 1));
                if (octet > 255)
                    return null;
                result[i] = (byte)octet;
            }
            return result;
        }

        Matcher ipv6Matcher = IPv6_PATTERN.matcher(address);
        if (ipv6Matcher.matches())
        {
            byte[] result = new byte[16];
            for (int i = 0; i < result.length; i += 2)
            {
                int word = Integer.parseInt(ipv6Matcher.group(i / 2 + 1), 16);
                result[i] = (byte)((word & 0xFF00) >>> 8);
                result[i + 1] = (byte)(word & 0xFF);
            }
            return result;
        }
        return null;
    }

    /**
     * Builds a network mask of {@code length} bytes whose {@code prefix} most
     * significant bits are 1. Callers must ensure {@code 0 <= prefix <= length * 8};
     * a full-length prefix (e.g. /32) yields an all-ones mask.
     */
    private byte[] prefixToBytes(int prefix, int length)
    {
        byte[] result = new byte[length];
        int fullBytes = prefix / 8;
        for (int i = 0; i < fullBytes; i++)
            result[i] = (byte)0xFF;
        int remainingBits = prefix % 8;
        // remainingBits > 0 implies fullBytes < length, given the precondition
        if (remainingBits > 0)
            result[fullBytes] = (byte)(0xFF << (8 - remainingBits));
        return result;
    }

    /**
     * Stops the timer thread, cancels pending timeouts and discards all
     * trackers and whitelist entries.
     */
    public void destroy()
    {
        LOG.debug("Destroy {}",this);
        _running = false;
        // Null if init() never ran or failed before starting the thread
        if (_timerThread != null)
            _timerThread.interrupt();
        _requestTimeoutQ.cancelAll();
        _trackerTimeoutQ.cancelAll();
        _rateTrackers.clear();
        _whitelist.clear();
    }

    /**
     * Returns the user id, used to track this connection.
     * This SHOULD be overridden by subclasses.
     *
     * @param request the current request
     * @return a unique user id, if logged in; otherwise null.
     */
    protected String extractUserId(ServletRequest request)
    {
        return null;
    }

    /**
     * Get maximum number of requests from a connection per
     * second. Requests in excess of this are first delayed,
     * then throttled.
     *
     * @return maximum number of requests
     */
    public int getMaxRequestsPerSec()
    {
        return _maxRequestsPerSec;
    }

    /**
     * Set maximum number of requests from a connection per
     * second. Requests in excess of this are first delayed,
     * then throttled. Only trackers created after this call use the new value.
     *
     * @param value maximum number of requests
     */
    public void setMaxRequestsPerSec(int value)
    {
        _maxRequestsPerSec = value;
    }

    /**
     * Get delay (in milliseconds) that is applied to all requests
     * over the rate limit, before they are considered at all.
     *
     * @return delay (in milliseconds); 0 means no delay, negative means reject
     */
    public long getDelayMs()
    {
        return _delayMs;
    }

    /**
     * Set delay (in milliseconds) that is applied to all requests
     * over the rate limit, before they are considered at all.
     *
     * @param value delay (in milliseconds), 0 - no delay, -1 (or any negative value) - reject request
     */
    public void setDelayMs(long value)
    {
        _delayMs = value;
    }

    /**
     * Get maximum amount of time (in milliseconds) the filter will
     * wait, blocking, for the throttle semaphore.
     *
     * @return maximum wait time
     */
    public long getMaxWaitMs()
    {
        return _maxWaitMs;
    }

    /**
     * Set maximum amount of time (in milliseconds) the filter will
     * wait, blocking, for the throttle semaphore.
     *
     * @param value maximum wait time
     */
    public void setMaxWaitMs(long value)
    {
        _maxWaitMs = value;
    }

    /**
     * Get number of requests over the rate limit able to be
     * considered at once.
     *
     * @return number of requests
     */
    public int getThrottledRequests()
    {
        return _throttledRequests;
    }

    /**
     * Set number of requests over the rate limit able to be
     * considered at once.
     * <p>
     * The existing semaphore is resized in place rather than replaced, so
     * permits still held by in-flight requests are returned to the semaphore
     * that is actually in use. Shrinking below the number of permits currently
     * held makes new requests wait until enough of them are released.
     *
     * @param value number of requests
     */
    public synchronized void setThrottledRequests(int value)
    {
        AdjustableSemaphore passes = _passes;
        if (passes == null)
        {
            _passes = new AdjustableSemaphore(value);
        }
        else
        {
            int delta = value - _throttledRequests;
            if (delta > 0)
                passes.release(delta);
            else if (delta < 0)
                passes.reduce(-delta);
        }
        _throttledRequests = value;
    }

    /**
     * Get amount of time (in milliseconds) to wait asynchronously for the semaphore.
     *
     * @return wait time
     */
    public long getThrottleMs()
    {
        return _throttleMs;
    }

    /**
     * Set amount of time (in milliseconds) to wait asynchronously for the semaphore.
     *
     * @param value wait time
     */
    public void setThrottleMs(long value)
    {
        _throttleMs = value;
    }

    /**
     * Get maximum amount of time (in milliseconds) to allow
     * the request to process.
     *
     * @return maximum processing time
     */
    public long getMaxRequestMs()
    {
        return _maxRequestMs;
    }

    /**
     * Set maximum amount of time (in milliseconds) to allow
     * the request to process. The value is also applied to the request
     * timeout queue, so changes made after {@link #init(FilterConfig)} take effect.
     *
     * @param value maximum processing time
     */
    public void setMaxRequestMs(long value)
    {
        _maxRequestMs = value;
        _requestTimeoutQ.setDuration(value);
    }

    /**
     * Get maximum amount of time (in milliseconds) to keep track
     * of request rates for a connection, before deciding that
     * the user has gone away, and discarding it.
     *
     * @return maximum tracking time
     */
    public long getMaxIdleTrackerMs()
    {
        return _maxIdleTrackerMs;
    }

    /**
     * Set maximum amount of time (in milliseconds) to keep track
     * of request rates for a connection, before deciding that
     * the user has gone away, and discarding it. The value is also applied
     * to the tracker timeout queue, so changes made after
     * {@link #init(FilterConfig)} take effect.
     *
     * @param value maximum tracking time
     */
    public void setMaxIdleTrackerMs(long value)
    {
        _maxIdleTrackerMs = value;
        _trackerTimeoutQ.setDuration(value);
    }

    /**
     * Check flag to insert the DoSFilter headers into the response.
     *
     * @return value of the flag
     */
    public boolean isInsertHeaders()
    {
        return _insertHeaders;
    }

    /**
     * Set flag to insert the DoSFilter headers into the response.
     *
     * @param value value of the flag
     */
    public void setInsertHeaders(boolean value)
    {
        _insertHeaders = value;
    }

    /**
     * Get flag to have usage rate tracked by session if a session exists.
     *
     * @return value of the flag
     */
    public boolean isTrackSessions()
    {
        return _trackSessions;
    }

    /**
     * Set flag to have usage rate tracked by session if a session exists.
     *
     * @param value value of the flag
     */
    public void setTrackSessions(boolean value)
    {
        _trackSessions = value;
    }

    /**
     * Get flag to have usage rate tracked by IP+port (effectively connection)
     * if session tracking is not used.
     *
     * @return value of the flag
     */
    public boolean isRemotePort()
    {
        return _remotePort;
    }

    /**
     * Set flag to have usage rate tracked by IP+port (effectively connection)
     * if session tracking is not used.
     *
     * @param value value of the flag
     */
    public void setRemotePort(boolean value)
    {
        _remotePort = value;
    }

    /**
     * @return whether this filter is enabled
     */
    public boolean isEnabled()
    {
        return _enabled;
    }

    /**
     * @param enabled whether this filter is enabled
     */
    public void setEnabled(boolean enabled)
    {
        _enabled = enabled;
    }

    /**
     * Get a list of IP addresses that will not be rate limited.
     *
     * @return comma-separated whitelist
     */
    public String getWhitelist()
    {
        StringBuilder result = new StringBuilder();
        for (Iterator<String> iterator = _whitelist.iterator(); iterator.hasNext();)
        {
            result.append(iterator.next());
            if (iterator.hasNext())
                result.append(",");
        }
        return result.toString();
    }

    /**
     * Set a list of IP addresses that will not be rate limited.
     * Entries are trimmed and empty entries are skipped. Only trackers
     * created after this call are affected.
     *
     * @param value comma-separated whitelist
     */
    public void setWhitelist(String value)
    {
        List<String> result = new ArrayList<String>();
        for (String address : value.split(","))
            addWhitelistAddress(result, address);
        _whitelist.clear();
        _whitelist.addAll(result);
        LOG.debug("Whitelisted IP addresses: {}", result);
    }

    /**
     * Removes all whitelist entries.
     */
    public void clearWhitelist()
    {
        _whitelist.clear();
    }

    /**
     * @param address an address or CIDR range; surrounding whitespace is trimmed
     * @return true if the address was added, false if it was blank
     */
    public boolean addWhitelistAddress(String address)
    {
        return addWhitelistAddress(_whitelist, address);
    }

    private boolean addWhitelistAddress(List<String> list, String address)
    {
        address = address.trim();
        return address.length() > 0 && list.add(address);
    }

    /**
     * @param address the exact whitelist entry to remove
     * @return true if an entry was removed
     */
    public boolean removeWhitelistAddress(String address)
    {
        return _whitelist.remove(address);
    }

    /**
     * A fair {@link Semaphore} whose number of permits can be reduced at runtime,
     * so {@link #setThrottledRequests(int)} can resize the throttle without
     * replacing the semaphore that in-flight requests release their permits to.
     */
    private static final class AdjustableSemaphore extends Semaphore
    {
        private static final long serialVersionUID = 1L;

        AdjustableSemaphore(int permits)
        {
            super(permits, true);
        }

        void reduce(int reduction)
        {
            reducePermits(reduction);
        }
    }

    /**
     * A RateTracker is associated with a connection, and stores request rate
     * data in a ring buffer holding the timestamps of the most recent requests
     * (one slot per allowed request per second).
     */
    class RateTracker extends Timeout.Task implements HttpSessionBindingListener, HttpSessionActivationListener, Serializable
    {
        private static final long serialVersionUID = 3534663738034577872L;

        transient protected final String _id;
        transient protected final int _type;
        transient protected final long[] _timestamps;
        transient protected int _next;

        public RateTracker(String id, int type, int maxRequestsPerSecond)
        {
            _id = id;
            _type = type;
            _timestamps = new long[maxRequestsPerSecond];
            _next = 0;
        }

        /**
         * Records a request at time {@code now} and checks the rate limit.
         *
         * @param now the current time in milliseconds
         * @return true if the request recorded {@code maxRequestsPerSecond}
         * requests ago happened less than 1000 ms before {@code now}
         */
        public boolean isRateExceeded(long now)
        {
            final long last;
            synchronized (this)
            {
                last = _timestamps[_next];
                _timestamps[_next] = now;
                _next = (_next + 1) % _timestamps.length;
            }

            return last != 0 && (now - last) < 1000L;
        }

        public String getId()
        {
            return _id;
        }

        public int getType()
        {
            return _type;
        }

        public void valueBound(HttpSessionBindingEvent event)
        {
            if (LOG.isDebugEnabled())
                LOG.debug("Value bound: {}", getId());
        }

        public void valueUnbound(HttpSessionBindingEvent event)
        {
            // Take the tracker out of the list of trackers, unless it was already replaced
            _rateTrackers.remove(_id, this);
            if (LOG.isDebugEnabled())
                LOG.debug("Tracker removed: {}", getId());
        }

        public void sessionWillPassivate(HttpSessionEvent se)
        {
            // Take the tracker out of the list of trackers (if it is still there)
            // and take ourselves out of the session so we are not saved
            _rateTrackers.remove(_id, this);
            se.getSession().removeAttribute(__TRACKER);
            if (LOG.isDebugEnabled()) LOG.debug("Value removed: {}", getId());
        }

        public void sessionDidActivate(HttpSessionEvent se)
        {
            LOG.warn("Unexpected session activation");
        }

        /**
         * Called by the tracker timeout queue. Keeps the tracker if its latest
         * request is under a second old; otherwise drops it, since a fresh tracker
         * would then report the same rate.
         */
        public void expired()
        {
            long now = _trackerTimeoutQ.getNow();
            final long last;
            // Same lock as isRateExceeded, which mutates _timestamps/_next
            synchronized (this)
            {
                int latestIndex = _next == 0 ? (_timestamps.length - 1) : (_next - 1);
                last = _timestamps[latestIndex];
            }
            boolean hasRecentRequest = last != 0 && (now - last) < 1000L;
            if (hasRecentRequest)
                reschedule();
            else
                _rateTrackers.remove(_id, this);
        }

        @Override
        public String toString()
        {
            return "RateTracker/" + _id + "/" + _type;
        }
    }

    /**
     * Tracker for whitelisted addresses: never reports the rate as exceeded.
     */
    class FixedRateTracker extends RateTracker
    {
        public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
        {
            super(id, type, numRecentRequestsTracked);
        }

        @Override
        public boolean isRateExceeded(long now)
        {
            // Rate limit is never exceeded, but we keep track of the request timestamps
            // so that we know whether there was recent activity on this tracker
            // and whether it should be expired
            synchronized (this)
            {
                _timestamps[_next] = now;
                _next = (_next + 1) % _timestamps.length;
            }

            return false;
        }

        @Override
        public String toString()
        {
            return "Fixed" + super.toString();
        }
    }
}