blob: ea1c353b48b95d114ff24defc29dc12a7478a29d [file] [log] [blame]
/*
* Copyright (C) 2012 Google Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.inject.servlet;
import static com.google.common.base.Charsets.UTF_8;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.base.Joiner;
import com.google.common.base.Splitter;
import com.google.common.net.UrlEscapers;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
/**
* Some servlet utility methods.
*
* @author ntang@google.com (Michael Tang)
*/
final class ServletUtils {
private static final Splitter SLASH_SPLITTER = Splitter.on('/');
private static final Joiner SLASH_JOINER = Joiner.on('/');
private ServletUtils() {
// private to prevent instantiation.
}
/**
* Gets the context path relative path of the URI. Returns the path of the resource relative to
* the context path for a request's URI, or null if no path can be extracted.
*
* <p>Also performs url decoding and normalization of the path.
*/
// @Nullable
static String getContextRelativePath(
// @Nullable
final HttpServletRequest request) {
if (request != null) {
String contextPath = request.getContextPath();
String requestURI = request.getRequestURI();
if (contextPath.length() < requestURI.length()) {
String suffix = requestURI.substring(contextPath.length());
return normalizePath(suffix);
} else if (requestURI.trim().length() > 0 && contextPath.length() == requestURI.length()) {
return "/";
}
}
return null;
}
/** Normalizes a path by unescaping all safe, percent encoded characters. */
static String normalizePath(String path) {
StringBuilder sb = new StringBuilder(path.length());
int queryStart = path.indexOf('?');
String query = null;
if (queryStart != -1) {
query = path.substring(queryStart);
path = path.substring(0, queryStart);
}
// Normalize the path. we need to decode path segments, normalize and rejoin in order to
// 1. decode and normalize safe percent escaped characters. e.g. %70 -> 'p'
// 2. decode and interpret dangerous character sequences. e.g. /%2E/ -> '/./' -> '/'
// 3. preserve dangerous encoded characters. e.g. '/%2F/' -> '///' -> '/%2F'
List<String> segments = new ArrayList<>();
for (String segment : SLASH_SPLITTER.split(path)) {
// This decodes all non-special characters from the path segment. so if someone passes
// /%2E/foo we will normalize it to /./foo and then /foo
String normalized =
UrlEscapers.urlPathSegmentEscaper().escape(lenientDecode(segment, UTF_8, false));
if (".".equals(normalized)) {
// skip
} else if ("..".equals(normalized)) {
if (segments.size() > 1) {
segments.remove(segments.size() - 1);
}
} else {
segments.add(normalized);
}
}
SLASH_JOINER.appendTo(sb, segments);
if (query != null) {
sb.append(query);
}
return sb.toString();
}
/**
* Percent-decodes a US-ASCII string into a Unicode string. The specified encoding is used to
* determine what characters are represented by any consecutive sequences of the form
* "%<i>XX</i>". This is the lenient kind of decoding that will simply ignore and copy as-is any
* "%XX" sequence that is invalid (for example, "%HH").
*
* @param string a percent-encoded US-ASCII string
* @param encoding a character encoding
* @param decodePlus boolean to indicate whether to decode '+' as ' '
* @return a Unicode string
*/
private static String lenientDecode(String string, Charset encoding, boolean decodePlus) {
checkNotNull(string);
checkNotNull(encoding);
if (decodePlus) {
string = string.replace('+', ' ');
}
int firstPercentPos = string.indexOf('%');
if (firstPercentPos < 0) {
return string;
}
ByteAccumulator accumulator = new ByteAccumulator(string.length(), encoding);
StringBuilder builder = new StringBuilder(string.length());
if (firstPercentPos > 0) {
builder.append(string, 0, firstPercentPos);
}
for (int srcPos = firstPercentPos; srcPos < string.length(); srcPos++) {
char c = string.charAt(srcPos);
if (c < 0x80) { // ASCII
boolean processed = false;
if (c == '%' && string.length() >= srcPos + 3) {
String hex = string.substring(srcPos + 1, srcPos + 3);
try {
int encoded = Integer.parseInt(hex, 16);
if (encoded >= 0) {
accumulator.append((byte) encoded);
srcPos += 2;
processed = true;
}
} catch (NumberFormatException ignore) {
// Expected case (badly formatted % group)
}
}
if (!processed) {
if (accumulator.isEmpty()) {
// We're not accumulating elements of a multibyte encoded
// char, so just toss it right into the result string.
builder.append(c);
} else {
accumulator.append((byte) c);
}
}
} else { // Non-ASCII
// A non-ASCII char marks the end of a multi-char encoding sequence,
// if one is in progress.
accumulator.dumpTo(builder);
builder.append(c);
}
}
accumulator.dumpTo(builder);
return builder.toString();
}
/** Accumulates byte sequences while decoding strings, and encodes them into a StringBuilder. */
private static class ByteAccumulator {
private byte[] bytes;
private int length;
private final Charset encoding;
ByteAccumulator(int capacity, Charset encoding) {
this.bytes = new byte[Math.min(16, capacity)];
this.encoding = encoding;
}
void append(byte b) {
ensureCapacity(length + 1);
bytes[length++] = b;
}
void dumpTo(StringBuilder dest) {
if (length != 0) {
dest.append(new String(bytes, 0, length, encoding));
length = 0;
}
}
boolean isEmpty() {
return length == 0;
}
private void ensureCapacity(int minCapacity) {
int cap = bytes.length;
if (cap >= minCapacity) {
return;
}
int newCapacity = cap + (cap >> 1); // *1.5
if (newCapacity < minCapacity) {
// we are close to overflowing, grow by smaller steps
newCapacity = minCapacity;
}
// in other cases, we will naturally throw an OOM from here
bytes = Arrays.copyOf(bytes, newCapacity);
}
}
}