Skip to content

Commit

Permalink
timeout more quickly
Browse files Browse the repository at this point in the history
  • Loading branch information
Phus Lu committed Oct 17, 2014
1 parent 85e63ec commit c30abd5
Showing 1 changed file with 25 additions and 54 deletions.
79 changes: 25 additions & 54 deletions local/proxylib.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@


gevent = sys.modules.get('gevent') or logging.warn('please enable gevent.')
NetWorkError = (socket.error, ssl.SSLError, OpenSSL.SSL.Error, OSError)


try:
Expand Down Expand Up @@ -809,7 +808,7 @@ def __io_copy(dest, source, timeout):
dest.sendall(data)
except socket.timeout:
pass
except NetWorkError as e:
except (socket.error, ssl.SSLError, OpenSSL.SSL.Error) as e:
if e.args[0] not in (errno.ECONNABORTED, errno.ECONNRESET, errno.ENOTCONN, errno.EPIPE):
raise
if e.args[0] in (errno.EBADF,):
Expand All @@ -824,43 +823,6 @@ def __io_copy(dest, source, timeout):
__io_copy(local, remote, timeout)


def deprecated_forward_socket(local, remote, timeout, bufsize):
"""deprecated forward socket"""
try:
tick = 1
timecount = timeout
while 1:
timecount -= tick
if timecount <= 0:
break
(ins, _, errors) = select.select([local, remote], [], [local, remote], tick)
if errors:
break
for sock in ins:
data = sock.recv(bufsize)
if not data:
break
if sock is remote:
local.sendall(data)
timecount = timeout
else:
remote.sendall(data)
timecount = timeout
except socket.timeout:
pass
except NetWorkError as e:
if e.args[0] not in (errno.ECONNABORTED, errno.ECONNRESET, errno.ENOTCONN, errno.EPIPE):
raise
if e.args[0] in (errno.EBADF,):
return
finally:
for sock in (remote, local):
try:
sock.close()
except StandardError:
pass


class LocalProxyServer(SocketServer.ThreadingTCPServer):
"""Local Proxy Server"""
request_queue_size = 1024
Expand All @@ -876,15 +838,15 @@ def close_request(self, request):
def finish_request(self, request, client_address):
try:
self.RequestHandlerClass(request, client_address, self)
except NetWorkError as e:
except (socket.error, ssl.SSLError, OpenSSL.SSL.Error) as e:
if e[0] not in (errno.ECONNABORTED, errno.ECONNRESET, errno.EPIPE):
raise

def handle_error(self, *args):
"""make ThreadingTCPServer happy"""
exc_info = sys.exc_info()
error = exc_info and len(exc_info) and exc_info[1]
if isinstance(error, NetWorkError) and len(error.args) > 1 and 'bad write retry' in error.args[1]:
if isinstance(error, (socket.error, ssl.SSLError, OpenSSL.SSL.Error)) and len(error.args) > 1 and 'bad write retry' in error.args[1]:
exc_info = error = None
else:
del exc_info, error
Expand Down Expand Up @@ -1003,7 +965,7 @@ def handle(self, handler, do_ssl_handshake=True):
if e[0] == -1 and 'Unexpected EOF' in e[1]:
return
raise
except NetWorkError as e:
except (socket.error, ssl.SSLError, OpenSSL.SSL.Error) as e:
if e.args[0] not in (errno.ECONNABORTED, errno.ECONNRESET):
logging.exception('ssl.wrap_socket(connection=%r) failed: %s', handler.connection, e)
return
Expand All @@ -1027,15 +989,15 @@ def handle(self, handler, do_ssl_handshake=True):
handler.send_error(400)
handler.wfile.close()
return
except NetWorkError as e:
except (socket.error, ssl.SSLError, OpenSSL.SSL.Error) as e:
if e.args[0] in (errno.ECONNABORTED, errno.ECONNRESET, errno.EPIPE):
handler.close_connection = 1
return
else:
raise
try:
handler.do_METHOD()
except NetWorkError as e:
except (socket.error, ssl.SSLError, OpenSSL.SSL.Error) as e:
if e.args[0] not in (errno.ECONNABORTED, errno.ETIMEDOUT, errno.EPIPE):
raise

Expand Down Expand Up @@ -1460,7 +1422,7 @@ def finish(self):
"""make python2 BaseHTTPRequestHandler happy"""
try:
BaseHTTPServer.BaseHTTPRequestHandler.finish(self)
except NetWorkError as e:
except (socket.error, ssl.SSLError, OpenSSL.SSL.Error) as e:
if e[0] not in (errno.ECONNABORTED, errno.ECONNRESET, errno.EPIPE):
raise

Expand Down Expand Up @@ -1677,7 +1639,7 @@ def close_connection(count, queobj, first_tcp_time):
for _ in range(count):
sock = queobj.get()
tcp_time_threshold = min(1, 1.3 * first_tcp_time)
if sock and not isinstance(sock, Exception):
if sock and hasattr(sock, 'getpeername'):
if cache_key and (sock.getpeername()[0] in self.iplist_predefined or self.tcp_connection_cachesock) and sock.tcp_time < tcp_time_threshold:
cache_queue = self.tcp_connection_cache[cache_key]
if cache_queue.qsize() < 8:
Expand Down Expand Up @@ -1736,13 +1698,13 @@ def reorg_ipaddrs():
thread.start_new_thread(create_connection, (addr, timeout, queobj))
for i in range(len(addrs)):
sock = queobj.get()
if not isinstance(sock, Exception):
thread.start_new_thread(close_connection, (len(addrs)-i-1, queobj, getattr(sock, 'tcp_time') or self.tcp_connection_time[sock.getpeername()]))
if hasattr(sock, 'getpeername'):
spawn_later(0.01, close_connection, len(addrs)-i-1, queobj, getattr(sock, 'tcp_time') or self.tcp_connection_time[sock.getpeername()])
return sock
elif i == 0:
# only output first error
logging.warning('create_tcp_connection to %r with %s return %r, try again.', hostname, addrs, sock)
if isinstance(sock, Exception):
if not hasattr(sock, 'getpeername'):
raise sock

def create_ssl_connection(self, hostname, port, timeout, **kwargs):
Expand Down Expand Up @@ -1837,6 +1799,12 @@ def create_connection(ipaddr, timeout, queobj):
def create_connection_withopenssl(ipaddr, timeout, queobj):
sock = None
ssl_sock = None
timer = None
NetworkError = (socket.error, OpenSSL.SSL.Error, OSError)
if gevent:
NetworkError += (gevent.Timeout,)
timer = gevent.Timeout(timeout)
timer.start()
try:
# create a ipv4/ipv6 socket object
sock = socket.socket(socket.AF_INET if ':' not in ipaddr[0] else socket.AF_INET6)
Expand Down Expand Up @@ -1899,7 +1867,7 @@ def create_connection_withopenssl(ipaddr, timeout, queobj):
response.close()
# put ssl socket object to output queobj
queobj.put(ssl_sock)
except (socket.error, OpenSSL.SSL.Error, OSError) as e:
except NetworkError as e:
# any socket.error, put Excpetions to output queobj.
queobj.put(e)
# reset a large and random timeout to the ipaddr
Expand All @@ -1916,11 +1884,14 @@ def create_connection_withopenssl(ipaddr, timeout, queobj):
# close tcp socket
if sock:
sock.close()
finally:
if timer:
timer.cancel()
def close_connection(count, queobj, first_tcp_time, first_ssl_time):
for _ in range(count):
sock = queobj.get()
ssl_time_threshold = min(1, 1.3 * first_ssl_time)
if sock and not isinstance(sock, Exception):
if sock and hasattr(sock, 'getpeername'):
if cache_key and (sock.getpeername()[0] in self.iplist_predefined or self.ssl_connection_cachesock) and sock.ssl_time < ssl_time_threshold:
cache_queue = self.ssl_connection_cache[cache_key]
if cache_queue.qsize() < 8:
Expand Down Expand Up @@ -1984,14 +1955,14 @@ def reorg_ipaddrs():
errors = []
for i in range(len(addrs)):
sock = queobj.get()
if not isinstance(sock, Exception):
thread.start_new_thread(close_connection, (len(addrs)-i-1, queobj, sock.tcp_time, sock.ssl_time))
if hasattr(sock, 'getpeername'):
spawn_later(0.01, close_connection, len(addrs)-i-1, queobj, sock.tcp_time, sock.ssl_time)
return sock
else:
errors.append(sock)
if i == len(addrs) - 1:
logging.warning('create_ssl_connection to %r with %s return %s, try again.', hostname, addrs, collections.OrderedDict.fromkeys(str(x) for x in errors).keys())
if isinstance(sock, Exception):
if not hasattr(sock, 'getpeername'):
raise sock

def create_http_request(self, method, url, headers, body, timeout, max_retry=2, bufsize=8192, crlf=None, validate=None, cache_key=None, headfirst=False, **kwargs):
Expand Down

0 comments on commit c30abd5

Please sign in to comment.