Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for query variables in preheat kernel mode #999

Merged
Prev Previous commit
Next Next commit
Move environment variables to ENV_VARIABLE
  • Loading branch information
trungleduc committed Oct 11, 2021
commit a703af8b24a4ff6ea0e1afe18c0ae126dcdd3fbf
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ install_requires =
jupyter_client>=6.1.3,<8
nbclient>=0.4.0,<0.6
nbconvert>=6.0.0,<7
websockets>=10.0,<11

[options.extras_require]
dev =
Expand Down
11 changes: 1 addition & 10 deletions voila/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
from .exporter import VoilaExporter
from .shutdown_kernel_handler import VoilaShutdownKernelHandler
from .voila_kernel_manager import voila_kernel_manager_factory
from .query_parameters_handler import QueryParametersHandler, QueryStringSocketHandler
from .query_parameters_handler import QueryStringSocketHandler

_kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"

Expand Down Expand Up @@ -485,15 +485,6 @@ def start(self):
])

if preheat_kernel:
handlers.append(
(
url_path_join(self.server_url, r'/voila/env/%s\/(?P<var_name>.*)' % _kernel_id_regex),
QueryParametersHandler,
{
'kernel_manager': self.kernel_manager
}
)
)
handlers.append(
(
url_path_join(self.server_url, r'/voila/query/%s' % _kernel_id_regex),
Expand Down
23 changes: 10 additions & 13 deletions voila/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ._version import __version__
from .notebook_renderer import NotebookRenderer
from .query_parameters_handler import QueryStringSocketHandler
from .utils import ENV_VARIABLE


class VoilaHandler(JupyterHandler):
Expand All @@ -46,17 +47,16 @@ async def get(self, path=None):

# Adding request uri to kernel env
kernel_env = os.environ.copy()
kernel_env['SCRIPT_NAME'] = self.request.path
kernel_env[ENV_VARIABLE.SCRIPT_NAME] = self.request.path
kernel_env[
'PATH_INFO'
ENV_VARIABLE.PATH_INFO
] = '' # would be /foo/bar if voila.ipynb/foo/bar was supported
kernel_env['QUERY_STRING'] = str(self.request.query)
kernel_env['SERVER_SOFTWARE'] = 'voila/{}'.format(__version__)
kernel_env['SERVER_PROTOCOL'] = str(self.request.version)
kernel_env[ENV_VARIABLE.QUERY_STRING] = str(self.request.query)
kernel_env[ENV_VARIABLE.SERVER_SOFTWARE] = 'voila/{}'.format(__version__)
kernel_env[ENV_VARIABLE.SERVER_PROTOCOL] = str(self.request.version)
host, port = split_host_and_port(self.request.host.lower())
kernel_env['SERVER_PORT'] = str(port) if port else ''
kernel_env['SERVER_NAME'] = host

kernel_env[ENV_VARIABLE.SERVER_PORT] = str(port) if port else ''
kernel_env[ENV_VARIABLE.SERVER_NAME] = host
# Add HTTP Headers as env vars following rfc3875#section-4.1.18
if len(self.voila_configuration.http_header_envs) > 0:
for header_name in self.request.headers:
Expand Down Expand Up @@ -145,6 +145,8 @@ def time_out():
self.write('<script>voila_heartbeat()</script>\n')
self.flush()

kernel_env[ENV_VARIABLE.VOILA_PREHEAT] = 'False'
kernel_env[ENV_VARIABLE.VOILA_BASE_URL] = self.base_url
kernel_id = await ensure_async(
(
self.kernel_manager.start_kernel(
Expand Down Expand Up @@ -186,10 +188,5 @@ def should_use_rendered_notebook(
return False
if theme is not None and rendered_theme != theme:
return False
# args_list = [
# key for key in request_args if key not in ['voila-template', 'voila-theme']
# ]
# if len(args_list) > 0:
# return False

return True
6 changes: 2 additions & 4 deletions voila/notebook_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .execute import VoilaExecutor, strip_code_cell_warnings
from .exporter import VoilaExporter
from .paths import collect_template_paths

from .utils import ENV_VARIABLE

class NotebookRenderer(LoggingConfigurable):
"""Render the notebook into HTML string."""
Expand Down Expand Up @@ -225,9 +225,7 @@ async def _jinja_kernel_start(self, nb, kernel_id, kernel_future):
await ensure_async(
self.executor.kc.execute(
f'''import os
\nos.environ["VOILA_KERNEL_ID"]="{kernel_id}"
\nos.environ["VOILA_PREHEAT"]= "{self.voila_configuration.preheat_kernel}"
\nos.environ["VOILA_BASE_URL"]="{self.base_url}"
\nos.environ["{ENV_VARIABLE.VOILA_KERNEL_ID}"]="{kernel_id}"
''',
store_history=False,
)
Expand Down
26 changes: 1 addition & 25 deletions voila/query_parameters_handler.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,31 @@
from tornado.web import RequestHandler
from tornado.websocket import WebSocketHandler
import logging


class QueryParametersHandler(RequestHandler):

def initialize(self, kernel_manager=None):
self._kernel_manager = None
if hasattr(kernel_manager, 'get_query_params'):
self._kernel_manager = kernel_manager

async def get(self, kernel_id: str, var_name: str):
if self._kernel_manager is not None:
content = self._kernel_manager.get_query_params(kernel_id, var_name)
self.finish(content[0])
else:
self.finish(None)


class QueryStringSocketHandler(WebSocketHandler):
waiters = dict()
cache = dict()

def get_compression_options(self):
# Non-None enables compression with default options.
return {}

def open(self, kernel_id):
print("open connection to", kernel_id)
QueryStringSocketHandler.waiters[kernel_id] = self
if kernel_id in self.cache:
print('sending', self.cache[kernel_id])
self.write_message(self.cache[kernel_id])

def on_close(self):
for k_id, waiter in QueryStringSocketHandler.waiters.items():
if waiter == self:
break
print('closing', k_id)
del QueryStringSocketHandler.waiters[k_id]

@classmethod
def send_updates(cls, msg):
kernel_id = msg['kernel_id']
payload = msg['payload']
print("sending message to %d waiters", kernel_id)
waiter = cls.waiters.get(kernel_id, None)
if waiter is not None:
try:
waiter.write_message(payload)
except Exception:
logging.error("Error sending message", exc_info=True)
else:
cls.cache[kernel_id] = payload
cls.cache[kernel_id] = payload
38 changes: 38 additions & 0 deletions voila/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,25 @@
#############################################################################

import os
import websockets
from typing import Awaitable
from enum import Enum


class ENV_VARIABLE(str, Enum):

VOILA_PREHEAT = 'VOILA_PREHEAT'
VOILA_KERNEL_ID = 'VOILA_KERNEL_ID'
VOILA_BASE_URL = 'VOILA_BASE_URL'
VOILA_APP_IP = 'VOILA_APP_IP'
VOILA_APP_PORT = 'VOILA_APP_PORT'
SERVER_NAME = 'SERVER_NAME'
SERVER_PORT = 'SERVER_PORT'
SCRIPT_NAME = 'SCRIPT_NAME'
PATH_INFO = 'PATH_INFO'
QUERY_STRING = 'QUERY_STRING'
SERVER_SOFTWARE = 'SERVER_SOFTWARE'
SERVER_PROTOCOL = 'SERVER_PROTOCOL'


def get_server_root_dir(settings):
Expand All @@ -24,3 +43,22 @@ def get_server_root_dir(settings):
# collapse $HOME to ~
root_dir = '~' + root_dir[len(home):]
return root_dir


async def get_user_query(url: str = None) -> Awaitable:
if url is None:
base_url = os.getenv(ENV_VARIABLE.VOILA_BASE_URL, '/')
server_ip = os.getenv(ENV_VARIABLE.VOILA_APP_IP, '127.0.0.1')
server_port = os.getenv(ENV_VARIABLE.VOILA_APP_PORT, '8866')
url = f'ws://{server_ip}:{server_port}{base_url}voila/query'

preheat_mode = os.getenv(ENV_VARIABLE.VOILA_PREHEAT, 'False')
kernel_id = os.getenv(ENV_VARIABLE.VOILA_KERNEL_ID)
ws_url = f'{url}/{kernel_id}'

if preheat_mode == 'True':
async with websockets.connect(ws_url) as websocket:
qs = await websocket.recv()
else:
qs = os.getenv(ENV_VARIABLE.QUERY_STRING)
return qs
4 changes: 3 additions & 1 deletion voila/voila_kernel_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from nbclient.util import ensure_async
import re
from .notebook_renderer import NotebookRenderer
from .utils import ENV_VARIABLE

T = TypeVar('T')

Expand Down Expand Up @@ -195,6 +196,8 @@ def fill_if_needed(
for key in kernel_env_variables:
if key not in kernel_env:
kernel_env[key] = kernel_env_variables[key]
kernel_env[ENV_VARIABLE.VOILA_BASE_URL] = self.parent.base_url
kernel_env[ENV_VARIABLE.VOILA_PREHEAT] = 'True'
kwargs['env'] = kernel_env

heated = len(pool)
Expand Down Expand Up @@ -287,7 +290,6 @@ async def _initialize(

kernel_future = self.get_kernel(kernel_id)
task = asyncio.get_event_loop().create_task(renderer.generate_content_hybrid(kernel_id, kernel_future))
task.add_done_callback(lambda _: print('done', kernel_id))
return {'task': task, 'renderer': renderer, 'kernel_id': kernel_id}

async def cull_kernel_if_idle(self, kernel_id: str):
Expand Down