From c30abd514d12bd92e6c0aee0ebb8eb2860e0b5f4 Mon Sep 17 00:00:00 2001 From: Phus Lu Date: Fri, 17 Oct 2014 18:01:09 +0800 Subject: [PATCH] timeout more quickly --- local/proxylib.py | 79 +++++++++++++++-------------------------------- 1 file changed, 25 insertions(+), 54 deletions(-) diff --git a/local/proxylib.py b/local/proxylib.py index 8dc4f9d6d..63fdb17d9 100644 --- a/local/proxylib.py +++ b/local/proxylib.py @@ -38,7 +38,6 @@ gevent = sys.modules.get('gevent') or logging.warn('please enable gevent.') -NetWorkError = (socket.error, ssl.SSLError, OpenSSL.SSL.Error, OSError) try: @@ -809,7 +808,7 @@ def __io_copy(dest, source, timeout): dest.sendall(data) except socket.timeout: pass - except NetWorkError as e: + except (socket.error, ssl.SSLError, OpenSSL.SSL.Error) as e: if e.args[0] not in (errno.ECONNABORTED, errno.ECONNRESET, errno.ENOTCONN, errno.EPIPE): raise if e.args[0] in (errno.EBADF,): @@ -824,43 +823,6 @@ def __io_copy(dest, source, timeout): __io_copy(local, remote, timeout) -def deprecated_forward_socket(local, remote, timeout, bufsize): - """deprecated forward socket""" - try: - tick = 1 - timecount = timeout - while 1: - timecount -= tick - if timecount <= 0: - break - (ins, _, errors) = select.select([local, remote], [], [local, remote], tick) - if errors: - break - for sock in ins: - data = sock.recv(bufsize) - if not data: - break - if sock is remote: - local.sendall(data) - timecount = timeout - else: - remote.sendall(data) - timecount = timeout - except socket.timeout: - pass - except NetWorkError as e: - if e.args[0] not in (errno.ECONNABORTED, errno.ECONNRESET, errno.ENOTCONN, errno.EPIPE): - raise - if e.args[0] in (errno.EBADF,): - return - finally: - for sock in (remote, local): - try: - sock.close() - except StandardError: - pass - - class LocalProxyServer(SocketServer.ThreadingTCPServer): """Local Proxy Server""" request_queue_size = 1024 @@ -876,7 +838,7 @@ def close_request(self, request): def finish_request(self, request, client_address): try: self.RequestHandlerClass(request, client_address, self) - except NetWorkError as e: + except (socket.error, ssl.SSLError, OpenSSL.SSL.Error) as e: if e[0] not in (errno.ECONNABORTED, errno.ECONNRESET, errno.EPIPE): raise @@ -884,7 +846,7 @@ def handle_error(self, *args): """make ThreadingTCPServer happy""" exc_info = sys.exc_info() error = exc_info and len(exc_info) and exc_info[1] - if isinstance(error, NetWorkError) and len(error.args) > 1 and 'bad write retry' in error.args[1]: + if isinstance(error, (socket.error, ssl.SSLError, OpenSSL.SSL.Error)) and len(error.args) > 1 and 'bad write retry' in error.args[1]: exc_info = error = None else: del exc_info, error @@ -1003,7 +965,7 @@ def handle(self, handler, do_ssl_handshake=True): if e[0] == -1 and 'Unexpected EOF' in e[1]: return raise - except NetWorkError as e: + except (socket.error, ssl.SSLError, OpenSSL.SSL.Error) as e: if e.args[0] not in (errno.ECONNABORTED, errno.ECONNRESET): logging.exception('ssl.wrap_socket(connection=%r) failed: %s', handler.connection, e) return @@ -1027,7 +989,7 @@ def handle(self, handler, do_ssl_handshake=True): handler.send_error(400) handler.wfile.close() return - except NetWorkError as e: + except (socket.error, ssl.SSLError, OpenSSL.SSL.Error) as e: if e.args[0] in (errno.ECONNABORTED, errno.ECONNRESET, errno.EPIPE): handler.close_connection = 1 return @@ -1035,7 +997,7 @@ def handle(self, handler, do_ssl_handshake=True): raise try: handler.do_METHOD() - except NetWorkError as e: + except (socket.error, ssl.SSLError, OpenSSL.SSL.Error) as e: if e.args[0] not in (errno.ECONNABORTED, errno.ETIMEDOUT, errno.EPIPE): raise @@ -1460,7 +1422,7 @@ def finish(self): """make python2 BaseHTTPRequestHandler happy""" try: BaseHTTPServer.BaseHTTPRequestHandler.finish(self) - except NetWorkError as e: + except (socket.error, ssl.SSLError, OpenSSL.SSL.Error) as e: if e[0] not in (errno.ECONNABORTED, errno.ECONNRESET, errno.EPIPE): raise @@ -1677,7 +1639,7 @@ def close_connection(count, queobj, first_tcp_time): for _ in range(count): sock = queobj.get() tcp_time_threshold = min(1, 1.3 * first_tcp_time) - if sock and not isinstance(sock, Exception): + if sock and hasattr(sock, 'getpeername'): if cache_key and (sock.getpeername()[0] in self.iplist_predefined or self.tcp_connection_cachesock) and sock.tcp_time < tcp_time_threshold: cache_queue = self.tcp_connection_cache[cache_key] if cache_queue.qsize() < 8: @@ -1736,13 +1698,13 @@ def reorg_ipaddrs(): thread.start_new_thread(create_connection, (addr, timeout, queobj)) for i in range(len(addrs)): sock = queobj.get() - if not isinstance(sock, Exception): - thread.start_new_thread(close_connection, (len(addrs)-i-1, queobj, getattr(sock, 'tcp_time') or self.tcp_connection_time[sock.getpeername()])) + if hasattr(sock, 'getpeername'): + spawn_later(0.01, close_connection, len(addrs)-i-1, queobj, getattr(sock, 'tcp_time') or self.tcp_connection_time[sock.getpeername()]) return sock elif i == 0: # only output first error logging.warning('create_tcp_connection to %r with %s return %r, try again.', hostname, addrs, sock) - if isinstance(sock, Exception): + if not hasattr(sock, 'getpeername'): raise sock def create_ssl_connection(self, hostname, port, timeout, **kwargs): @@ -1837,6 +1799,12 @@ def create_connection(ipaddr, timeout, queobj): def create_connection_withopenssl(ipaddr, timeout, queobj): sock = None ssl_sock = None + timer = None + NetworkError = (socket.error, OpenSSL.SSL.Error, OSError) + if gevent: + NetworkError += (gevent.Timeout,) + timer = gevent.Timeout(timeout) + timer.start() try: # create a ipv4/ipv6 socket object sock = socket.socket(socket.AF_INET if ':' not in ipaddr[0] else socket.AF_INET6) @@ -1899,7 +1867,7 @@ def create_connection_withopenssl(ipaddr, timeout, queobj): response.close() # put ssl socket object to output queobj queobj.put(ssl_sock) - except (socket.error, OpenSSL.SSL.Error, OSError) as e: + except NetworkError as e: # any socket.error, put Excpetions to output queobj. queobj.put(e) # reset a large and random timeout to the ipaddr @@ -1916,11 +1884,14 @@ def create_connection_withopenssl(ipaddr, timeout, queobj): # close tcp socket if sock: sock.close() + finally: + if timer: + timer.cancel() def close_connection(count, queobj, first_tcp_time, first_ssl_time): for _ in range(count): sock = queobj.get() ssl_time_threshold = min(1, 1.3 * first_ssl_time) - if sock and not isinstance(sock, Exception): + if sock and hasattr(sock, 'getpeername'): if cache_key and (sock.getpeername()[0] in self.iplist_predefined or self.ssl_connection_cachesock) and sock.ssl_time < ssl_time_threshold: cache_queue = self.ssl_connection_cache[cache_key] if cache_queue.qsize() < 8: @@ -1984,14 +1955,14 @@ def reorg_ipaddrs(): errors = [] for i in range(len(addrs)): sock = queobj.get() - if not isinstance(sock, Exception): - thread.start_new_thread(close_connection, (len(addrs)-i-1, queobj, sock.tcp_time, sock.ssl_time)) + if hasattr(sock, 'getpeername'): + spawn_later(0.01, close_connection, len(addrs)-i-1, queobj, sock.tcp_time, sock.ssl_time) return sock else: errors.append(sock) if i == len(addrs) - 1: logging.warning('create_ssl_connection to %r with %s return %s, try again.', hostname, addrs, collections.OrderedDict.fromkeys(str(x) for x in errors).keys()) - if isinstance(sock, Exception): + if not hasattr(sock, 'getpeername'): raise sock def create_http_request(self, method, url, headers, body, timeout, max_retry=2, bufsize=8192, crlf=None, validate=None, cache_key=None, headfirst=False, **kwargs):