Skip to content

Commit

Permalink
request:enable_response_body() method
Browse files Browse the repository at this point in the history
  • Loading branch information
kmike committed Aug 16, 2016
1 parent 577f839 commit 78ab179
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 20 deletions.
14 changes: 14 additions & 0 deletions docs/scripting-request-object.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ Drop the request.

**Async:** no.

.. _splash-request-enable-response-body:

request:enable_response_body
----------------------------

Enable tracking of the response body for this request, so that the
:ref:`splash-response-body` attribute is available for its response.

**Signature:** ``request:enable_response_body()``

**Returns:** nil.

**Async:** no.

.. _splash-request-set-url:

request:set_url
Expand Down
46 changes: 33 additions & 13 deletions splash/network_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,9 @@ class ProxiedQNetworkAccessManager(QNetworkAccessManager):
* Tracks information about requests/responses and stores it in HAR format,
including response content.
* Allows setting per-request timeouts.
"""

_REQUEST_ID = QNetworkRequest.User + 1
_SHOULD_TRACK = QNetworkRequest.User + 2

def __init__(self, verbosity):
super(ProxiedQNetworkAccessManager, self).__init__()
Expand All @@ -96,7 +95,7 @@ def __init__(self, verbosity):
self._default_proxy = self.proxy()
self.cookiejar = SplashCookieJar(self)
self.setCookieJar(self.cookiejar)
self._response_content = {} # requestId => response content
self._response_bodies = {} # requestId => response content
self._request_ids = itertools.count()
assert self.proxyFactory() is None, "Standard QNetworkProxyFactory is not supported"

Expand Down Expand Up @@ -126,6 +125,7 @@ def createRequest(self, operation, request, outgoingData=None):
request, operation, outgoingData)

self._handle_custom_proxies(request)
self._handle_request_response_tracking(request)

har = self._get_har(request)
if har is not None:
Expand All @@ -151,7 +151,11 @@ def createRequest(self, operation, request, outgoingData=None):

reply.error.connect(self._on_reply_error)
reply.finished.connect(self._on_reply_finished)
reply.readyRead.connect(self._on_reply_ready_read)

if self._should_track_content(request):
self._response_bodies[req_id] = QByteArray()
reply.readyRead.connect(self._on_reply_ready_read)

reply.metaDataChanged.connect(self._on_reply_headers)
reply.downloadProgress.connect(self._on_reply_download_progress)
return reply
Expand Down Expand Up @@ -190,8 +194,9 @@ def _wrap_request(self, request):
req = QNetworkRequest(request)
req_id = next(self._request_ids)
req.setAttribute(self._REQUEST_ID, req_id)
if hasattr(request, 'timeout'):
req.timeout = request.timeout
for attr in ['timeout', 'track']:
if hasattr(request, attr):
setattr(req, attr, getattr(request, attr))
return req, req_id

def _handle_custom_proxies(self, request):
Expand Down Expand Up @@ -232,6 +237,13 @@ def _handle_custom_headers(self, request):
def _handle_request_cookies(self, request):
self.cookiejar.update_cookie_header(request)

def _handle_request_response_tracking(self, request):
if hasattr(request, 'track'):
request.setAttribute(self._SHOULD_TRACK, request.track)
else:
# FIXME
request.setAttribute(self._SHOULD_TRACK, False)

def _handle_reply_cookies(self, reply):
self.cookiejar.fill_from_reply(reply)

Expand All @@ -240,6 +252,9 @@ def _get_request_id(self, request=None):
request = self.sender().request()
return request.attribute(self._REQUEST_ID)

def _should_track_content(self, request):
return request.attribute(self._SHOULD_TRACK)

def _get_har(self, request=None):
"""
Return HarBuilder instance.
Expand All @@ -260,7 +275,7 @@ def _set_webpage_attribute(self, request, attribute, value):
return setattr(web_frame.page(), attribute, value)

def _on_reply_error(self, error_id):
self._response_content.pop(self._get_request_id(), None)
self._response_bodies.pop(self._get_request_id(), None)

if error_id != QNetworkReply.OperationCanceledError:
error_msg = REQUEST_ERRORS.get(error_id, 'unknown error')
Expand All @@ -269,11 +284,16 @@ def _on_reply_error(self, error_id):

def _on_reply_ready_read(self):
reply = self.sender()
req_id = self._get_request_id()
self._store_response_chunk(reply)

def _store_response_chunk(self, reply):
req_id = self._get_request_id(reply.request())
if req_id not in self._response_bodies:
self.log("Internal problem in _store_response_chunk: "
"request %s is not tracked" % req_id, reply, min_level=1)
return
chunk = reply.peek(reply.bytesAvailable())
if req_id not in self._response_content:
self._response_content[req_id] = QByteArray()
self._response_content[req_id].append(chunk)
self._response_bodies[req_id].append(chunk)

def _on_reply_finished(self):
reply = self.sender()
Expand All @@ -284,8 +304,8 @@ def _on_reply_finished(self):
if har is not None:
req_id = self._get_request_id()
# FIXME: what if har is None? When can it be None?
# Who removes the content from self._response_content dict?
content = self._response_content.pop(req_id, None)
# Who removes the content from self._response_bodies dict?
content = self._response_bodies.pop(req_id, None)
if content is not None:
content = bytes(content)

Expand Down
19 changes: 14 additions & 5 deletions splash/qtrender_lua.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,6 +1337,11 @@ def _on_request_required(self, meth, attr_name):
def abort(self):
drop_request(self.request)

@command()
@requires_request
def enable_response_body(self):
    """Set ``track`` on the wrapped request to enable response body tracking."""
    self.request.track = True

@command()
@requires_request
def set_url(self, url):
Expand Down Expand Up @@ -1388,18 +1393,22 @@ def __init__(self, lua, exceptions, reply, exposed_request,
self._content = content
self._info = resp_info
self._info_lua = None
self._body_binary = None

@lua_property("body")
@command()
def get_body(self):
if self._body_binary is None:
body = self._content or get_response_body_bytes(self._info)
content_type = self._info['content']['mimeType']
self._body_binary = BinaryCapsule(body, content_type)
if not hasattr(self, '_body_binary'):
self._body_binary = self._get_body_object()
self._content = None
return self._body_binary

def _get_body_object(self):
    """
    Build the object exposed as the response body.

    The body bytes are taken from ``self._content`` when it is truthy,
    otherwise from ``get_response_body_bytes(self._info)``.

    Returns ``None`` when no body is available, otherwise a
    ``BinaryCapsule`` built from the body bytes and the mime type
    recorded in ``self._info``.
    """
    body = self._content or get_response_body_bytes(self._info)
    if body is None:
        return None
    content_type = self._info['content']['mimeType']
    return BinaryCapsule(body, content_type)

@lua_property("info")
@command()
def get_info(self):
Expand Down
1 change: 1 addition & 0 deletions splash/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def _on_render_error(self, failure, request):

def _on_internal_error(self, failure, request):
log.err()
# failure.printTraceback()
sentry.capture(failure)
# only propagate str value to avoid exposing internal details
ex = InternalError(str(failure.value))
Expand Down
40 changes: 38 additions & 2 deletions splash/tests/test_execute_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from PIL import Image
import six
from six.moves.urllib.parse import urlencode
import requests
import pytest

from splash.exceptions import ScriptError
Expand Down Expand Up @@ -215,7 +216,6 @@ def test_set_header(self):
""", {'url': self.mockurl("getrequest")})
self.assertStatusCode(resp, 200)


if six.PY3:
self.assertIn("b'custom-header': b'some-val'", resp.text)
self.assertIn("b'user-agent': b'Fooozilla'", resp.text)
Expand Down Expand Up @@ -284,6 +284,42 @@ def test_request_attrs(self):
self.assertEqual(req['method'], 'GET')
self.assertIn('Accept', req['headers'])

def test_enable_response_body(self):
    """
    Only requests which call ``enable_response_body()`` expose a body.

    The Lua script enables body tracking for the ``.gif`` request only,
    so exactly one body must be collected; it must equal the image as
    served by the mock server and also be present in the HAR entry.
    """
    page_url = self.mockurl('show-image')
    resp = self.request_lua("""
function main(splash)
splash:on_request(function(req)
if req.url:find(".gif") ~= nil then
req:enable_response_body()
end
end)
local bodies = {}
splash:on_response(function(resp, req)
bodies[resp.url] = resp.body
end)
assert(splash:go(splash.args.url))
return {har=splash:har(), bodies=bodies}
end
""", {'url': page_url})
    self.assertStatusCode(resp, 200)
    data = resp.json()

    # exactly one response (the image) must have a body collected
    bodies = data['bodies']
    assert len(bodies) == 1
    gif_url = next(iter(bodies))
    assert "slow.gif" in gif_url
    img_gif = requests.get(self.mockurl("slow.gif?n=0")).content
    body = base64.b64decode(bodies[gif_url])
    assert body == img_gif

    # the HAR log must contain the body for the tracked request only
    entries = data['har']['log']['entries']
    assert len(entries) == 2
    assert 'text' not in entries[0]['response']['content']
    assert entries[1]['response']['content']['encoding'] == 'base64'
    assert entries[1]['response']['content']['text'] == bodies[gif_url]


class OnResponseHeadersTest(BaseLuaRenderTest, BaseHtmlProxyTest):
def test_get_header(self):
Expand Down Expand Up @@ -1041,4 +1077,4 @@ def test_function_returns_several_values(self):
end
""")
self.assertStatusCode(resp, 200)
self.assertEqual(resp.json(), [1, 2, 3])
self.assertEqual(resp.json(), [1, 2, 3])

0 comments on commit 78ab179

Please sign in to comment.