# -*- coding: utf-8 -
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.

import re
import socket
from errno import ENOTCONN

from gunicorn._compat import bytes_to_str
from gunicorn.http.unreader import SocketUnreader
from gunicorn.http.body import ChunkedReader, LengthReader, EOFReader, Body
from gunicorn.http.errors import (InvalidHeader, InvalidHeaderName, NoMoreData,
    InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion,
    LimitRequestLine, LimitRequestHeaders)
from gunicorn.http.errors import InvalidProxyLine, ForbiddenProxyRequest
from gunicorn.six import BytesIO
from gunicorn._compat import urlsplit

# Hard upper bounds; configured limits outside these bounds are clamped.
MAX_REQUEST_LINE = 8190
MAX_HEADERS = 32768
MAX_HEADERFIELD_SIZE = 8190

# Characters that are rejected when found in a header name.  The brackets
# and the backslash are escaped for the regex engine; the backslashes are
# doubled so the (non-raw) string literal contains no invalid escapes.
HEADER_RE = re.compile("[\x00-\x1F\x7F()<>@,;:\\[\\]={} \t\\\\\"]")
# NOTE(review): inside the character class "$-_" is a *range* (0x24-0x5F),
# not three literal characters, and the pattern is only matched against the
# start of the method.  Left unchanged so that no currently accepted method
# gets rejected -- confirm the intended grammar before tightening.
METH_RE = re.compile(r"[A-Z0-9$-_.]{3,20}")
# The dot is escaped so that only "HTTP/<major>.<minor>" matches (an
# unescaped dot would also accept e.g. "HTTP/1x1").
VERSION_RE = re.compile(r"HTTP/(\d+)\.(\d+)")


class Message(object):
    """Base class for a parsed HTTP message.

    Construction parses the message head from ``unreader`` (through the
    subclass' ``parse``), pushes the unconsumed bytes back to the unreader
    and installs a body reader in ``self.body``.
    """

    def __init__(self, cfg, unreader):
        """Parse a message from ``unreader`` using the limits in ``cfg``.

        ``cfg`` must provide ``limit_request_fields`` and
        ``limit_request_field_size``.
        """
        self.cfg = cfg
        self.unreader = unreader
        self.version = None
        self.headers = []
        self.trailers = []
        self.body = None

        # set headers limits
        self.limit_request_fields = cfg.limit_request_fields
        if (self.limit_request_fields <= 0
                or self.limit_request_fields > MAX_HEADERS):
            self.limit_request_fields = MAX_HEADERS
        self.limit_request_field_size = cfg.limit_request_field_size
        # A value of 0 is kept as is: it disables the per-field size check
        # done in parse_headers().
        if (self.limit_request_field_size < 0
                or self.limit_request_field_size > MAX_HEADERFIELD_SIZE):
            self.limit_request_field_size = MAX_HEADERFIELD_SIZE

        # set max header buffer size; fall back to the hard maximum when the
        # per-field limit is 0 so the buffer bound is never zero.
        max_header_field_size = (self.limit_request_field_size
                                 or MAX_HEADERFIELD_SIZE)
        self.max_buffer_headers = self.limit_request_fields * \
            (max_header_field_size + 2) + 4

        unused = self.parse(self.unreader)
        self.unreader.unread(unused)
        self.set_body_reader()

    def parse(self, unreader):
        """Parse the message head from ``unreader``.

        Subclasses must override this and return the bytes that were read
        but not consumed.  The signature matches the call made in
        ``__init__``.
        """
        raise NotImplementedError()

    def parse_headers(self, data):
        """Parse a raw header block into a list of ``(NAME, value)`` pairs.

        ``data`` is the header block as bytes, lines separated by CRLF.
        Names are upper-cased; continuation lines (starting with a space or
        a tab) are appended to the value of the preceding header.

        Raises ``LimitRequestHeaders`` when there are too many fields or a
        field is too long, ``InvalidHeader`` when a line has no colon and
        ``InvalidHeaderName`` when the name contains a forbidden character.
        """
        headers = []

        # Split lines on \r\n keeping the \r\n on each line
        lines = [bytes_to_str(line) + "\r\n" for line in data.split(b"\r\n")]
        total = len(lines)
        # Cursor into ``lines``; avoids repeatedly popping from the front of
        # the list, which is linear for every header.
        pos = 0

        # Parse headers into key/value pairs paying attention
        # to continuation lines.
        while pos < total:
            if len(headers) >= self.limit_request_fields:
                raise LimitRequestHeaders("limit request headers fields")

            # Parse initial header name : value pair.
            curr = lines[pos]
            pos += 1
            header_length = len(curr)
            if curr.find(":") < 0:
                raise InvalidHeader(curr.strip())
            name, value = curr.split(":", 1)
            name = name.rstrip(" \t").upper()
            if HEADER_RE.search(name):
                raise InvalidHeaderName(name)

            name, value = name.strip(), [value.lstrip()]

            # Consume value continuation lines
            while pos < total and lines[pos].startswith((" ", "\t")):
                curr = lines[pos]
                pos += 1
                header_length += len(curr)
                # The size check is skipped when the limit is 0.
                if header_length > self.limit_request_field_size > 0:
                    raise LimitRequestHeaders("limit request headers "
                                              + "fields size")
                value.append(curr)
            value = ''.join(value).rstrip()

            if header_length > self.limit_request_field_size > 0:
                raise LimitRequestHeaders("limit request headers fields size")
            headers.append((name, value))
        return headers

    def set_body_reader(self):
        """Choose the body reader from the parsed headers.

        A chunked transfer encoding takes precedence over a content length;
        without either the body is read until EOF.  Raises ``InvalidHeader``
        when the content length is not a non-negative integer.
        """
        chunked = False
        content_length = None
        for (name, value) in self.headers:
            if name == "CONTENT-LENGTH":
                content_length = value
            elif name == "TRANSFER-ENCODING":
                chunked = value.lower() == "chunked"
            elif name == "SEC-WEBSOCKET-KEY1":
                content_length = 8

        if chunked:
            self.body = Body(ChunkedReader(self, self.unreader))
        elif content_length is not None:
            try:
                content_length = int(content_length)
            except ValueError:
                raise InvalidHeader("CONTENT-LENGTH", req=self)

            if content_length < 0:
                raise InvalidHeader("CONTENT-LENGTH", req=self)

            self.body = Body(LengthReader(self.unreader, content_length))
        else:
            self.body = Body(EOFReader(self.unreader))

    def should_close(self):
        """Return True when the connection should be closed.

        Only the first ``Connection`` header is considered: ``close`` gives
        True and ``keep-alive`` gives False.  For any other value, or when
        the header is absent, the result is True for HTTP versions up to
        and including 1.0.
        """
        for (h, v) in self.headers:
            if h == "CONNECTION":
                v = v.lower().strip()
                if v == "close":
                    return True
                elif v == "keep-alive":
                    return False
                break
        return self.version <= (1, 0)


class Request(Message):
    """An HTTP request: request line, optional PROXY line, and headers."""

    def __init__(self, cfg, unreader, req_number=1):
        """Parse a request from ``unreader``.

        ``req_number`` is the position of this request on its connection;
        a PROXY protocol line is only looked for when it equals 1.
        """
        self.method = None
        self.uri = None
        self.path = None
        self.query = None
        self.fragment = None

        # get max request line size; 0 is kept and disables the check done
        # in read_line().
        self.limit_request_line = cfg.limit_request_line
        if (self.limit_request_line < 0
                or self.limit_request_line >= MAX_REQUEST_LINE):
            self.limit_request_line = MAX_REQUEST_LINE

        self.req_number = req_number
        self.proxy_protocol_info = None
        super(Request, self).__init__(cfg, unreader)

    def get_data(self, unreader, buf, stop=False):
        """Read one chunk from ``unreader`` and append it to ``buf``.

        When nothing can be read, raises ``StopIteration`` if ``stop`` is
        true, otherwise ``NoMoreData`` carrying what was buffered so far.
        """
        data = unreader.read()
        if not data:
            if stop:
                raise StopIteration()
            raise NoMoreData(buf.getvalue())
        buf.write(data)

    def parse(self, unreader):
        """Parse the request line and headers; return the unconsumed bytes.

        Raises ``LimitRequestHeaders`` when the header block grows beyond
        ``self.max_buffer_headers``, plus whatever the request line, proxy
        line and header parsers raise.
        """
        buf = BytesIO()
        self.get_data(unreader, buf, stop=True)

        # get request line
        line, rbuf = self.read_line(unreader, buf, self.limit_request_line)

        # proxy protocol
        if self.proxy_protocol(bytes_to_str(line)):
            # get next request line
            buf = BytesIO()
            buf.write(rbuf)
            line, rbuf = self.read_line(unreader, buf, self.limit_request_line)

        self.parse_request_line(bytes_to_str(line))
        buf = BytesIO()
        buf.write(rbuf)

        # Headers: read until the blank line ending the header block is
        # buffered, or until the data starts with CRLF (no headers at all).
        data = buf.getvalue()
        while True:
            idx = data.find(b"\r\n\r\n")
            done = data[:2] == b"\r\n"
            if idx >= 0 or done:
                break

            self.get_data(unreader, buf)
            data = buf.getvalue()
            if len(data) > self.max_buffer_headers:
                raise LimitRequestHeaders("max buffer headers")

        if done:
            # No headers: give back everything after the CRLF.
            self.unreader.unread(data[2:])
            return b""

        self.headers = self.parse_headers(data[:idx])

        # Skip the \r\n\r\n separator; the rest belongs to the body.
        return data[idx + 4:]

    def read_line(self, unreader, buf, limit=0):
        """Read from ``unreader`` until ``buf`` holds a CRLF-terminated line.

        Returns ``(line, residue)`` where ``line`` excludes the CRLF and
        ``residue`` is what follows it.  Raises ``LimitRequestLine`` when
        ``limit`` is positive and the line is longer than ``limit``.
        """
        data = buf.getvalue()

        while True:
            idx = data.find(b"\r\n")
            if idx >= 0:
                # check if the request line is too large
                if idx > limit > 0:
                    raise LimitRequestLine(idx, limit)
                break
            elif len(data) - 2 > limit > 0:
                raise LimitRequestLine(len(data), limit)
            self.get_data(unreader, buf)
            data = buf.getvalue()

        return (data[:idx],  # request line,
                data[idx + 2:])  # residue in the buffer, skip \r\n

    def proxy_protocol(self, line):
        """\
        Detect, check and parse proxy protocol.

        :raises: ForbiddenProxyRequest, InvalidProxyLine.
        :return: True for proxy protocol line else False
        """
        if not self.cfg.proxy_protocol:
            return False

        if self.req_number != 1:
            return False

        if not line.startswith("PROXY"):
            return False

        self.proxy_protocol_access_check()
        self.parse_proxy_protocol(line)

        return True

    def proxy_protocol_access_check(self):
        """Check that the peer is allowed to send a PROXY line.

        Only performed when the unreader is a ``SocketUnreader``.  Raises
        ``ForbiddenProxyRequest`` when the peer address is not listed in
        ``cfg.proxy_allow_ips`` (``"*"`` allows every peer) or when the
        socket is not connected.
        """
        # check in allow list
        if isinstance(self.unreader, SocketUnreader):
            try:
                remote_host = self.unreader.sock.getpeername()[0]
            except socket.error as e:
                if e.args[0] == ENOTCONN:
                    raise ForbiddenProxyRequest("UNKNOW")
                raise
            if ("*" not in self.cfg.proxy_allow_ips and
                    remote_host not in self.cfg.proxy_allow_ips):
                raise ForbiddenProxyRequest(remote_host)

    def parse_proxy_protocol(self, line):
        """Validate a PROXY line and store it in ``proxy_protocol_info``.

        The line must have six whitespace-separated fields: the ``PROXY``
        keyword, the protocol (``TCP4`` or ``TCP6``), source and
        destination addresses, source and destination ports.  Raises
        ``InvalidProxyLine`` on any malformed field.
        """
        bits = line.split()

        if len(bits) != 6:
            raise InvalidProxyLine(line)

        # Extract data
        proto = bits[1]
        s_addr = bits[2]
        d_addr = bits[3]

        # Validation
        if proto not in ["TCP4", "TCP6"]:
            raise InvalidProxyLine("protocol '%s' not supported" % proto)
        if proto == "TCP4":
            try:
                socket.inet_pton(socket.AF_INET, s_addr)
                socket.inet_pton(socket.AF_INET, d_addr)
            except socket.error:
                raise InvalidProxyLine(line)
        elif proto == "TCP6":
            try:
                socket.inet_pton(socket.AF_INET6, s_addr)
                socket.inet_pton(socket.AF_INET6, d_addr)
            except socket.error:
                raise InvalidProxyLine(line)

        try:
            s_port = int(bits[4])
            d_port = int(bits[5])
        except ValueError:
            raise InvalidProxyLine("invalid port %s" % line)

        if not ((0 <= s_port <= 65535) and (0 <= d_port <= 65535)):
            raise InvalidProxyLine("invalid port %s" % line)

        # Set data
        self.proxy_protocol_info = {
            "proxy_protocol": proto,
            "client_addr": s_addr,
            "client_port": s_port,
            "proxy_addr": d_addr,
            "proxy_port": d_port
        }

    def parse_request_line(self, line):
        """Parse ``"METHOD URI HTTP/x.y"`` into the request attributes.

        Sets ``method``, ``uri``, ``path``, ``query``, ``fragment`` and
        ``version`` (a ``(major, minor)`` tuple of ints).  Raises
        ``InvalidRequestLine``, ``InvalidRequestMethod`` or
        ``InvalidHTTPVersion`` on malformed input.
        """
        bits = line.split(None, 2)
        if len(bits) != 3:
            raise InvalidRequestLine(line)

        # Method
        if not METH_RE.match(bits[0]):
            raise InvalidRequestMethod(bits[0])
        self.method = bits[0].upper()

        # URI
        # When the path starts with //, urlsplit considers it as a
        # relative uri while the RFC says it shouldn't
        # http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2
        # considers it as an absolute url.
        # fix issue #297
        if bits[1].startswith("//"):
            self.uri = bits[1][1:]
        else:
            self.uri = bits[1]

        try:
            parts = urlsplit(self.uri)
        except ValueError:
            raise InvalidRequestLine(line)
        self.path = parts.path or ""
        self.query = parts.query or ""
        self.fragment = parts.fragment or ""

        # Version
        match = VERSION_RE.match(bits[2])
        if match is None:
            raise InvalidHTTPVersion(bits[2])
        self.version = (int(match.group(1)), int(match.group(2)))

    def set_body_reader(self):
        """Install the body reader; a request never reads its body to EOF.

        When the base class selected an ``EOFReader`` (neither chunked nor
        a content length), it is replaced by a zero-length reader.
        """
        super(Request, self).set_body_reader()
        if isinstance(self.body.reader, EOFReader):
            self.body = Body(LengthReader(self.unreader, 0))