Skip to content

Commit

Permalink
add share option
Browse files Browse the repository at this point in the history
Signed-off-by: Yi Dong <[email protected]>
  • Loading branch information
doyend committed Oct 5, 2022
1 parent a5c8339 commit 125e499
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@ prompts: # prompts for GPT inference
server: False # whether to launch the API server
port: 5555 # the port number for the inference server
web_server: False # whether to launch the web inference server
share: False # whether to create a public URL
username: test # user name for web client
password: test2 # password for web client
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ batch_size: 16
server: False
port: 5555
web_server: False # whether to launch the web inference server
share: False # whether to create a public URL
username: test # user name for web client
password: test2 # password for web client
2 changes: 1 addition & 1 deletion examples/nlp/language_modeling/megatron_gpt_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def main(cfg) -> None:
if cfg.server:
if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0:
if cfg.web_server:
thread = threading.Thread(target=get_demo, daemon=True, args=(cfg.username, cfg.password))
thread = threading.Thread(target=get_demo, daemon=True, args=(cfg.share, cfg.username, cfg.password))
thread.start()
server = MegatronServer(model.cuda())
server.run("0.0.0.0", port=cfg.port)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def placeholder():
if cfg.server:
if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0:
if cfg.web_server:
thread = threading.Thread(target=get_demo, daemon=True, args=(cfg.username, cfg.password))
thread = threading.Thread(target=get_demo, daemon=True, args=(cfg.share, cfg.username, cfg.password))
thread.start()
server = MegatronServer(model.cuda())
server.run("0.0.0.0", port=cfg.port)
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/nlp/modules/common/megatron_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_generation(prompt, greedy, add_BOS, token_to_gen, min_tokens, temp, top_
return sentences[0]


def get_demo(username, password):
def get_demo(share, username, password):
with gr.Blocks() as demo:
with gr.Row():
with gr.Column(scale=2, width=200):
Expand All @@ -62,4 +62,4 @@ def get_demo(username, password):
output_box = gr.Textbox(value="", label="Output")
btn = gr.Button(value="Submit")
btn.click(get_generation, inputs=[input_prompt, greedy_flag, add_BOS, token_to_gen, min_token_to_gen, temperature, top_p, top_k, repetition_penality], outputs=[output_box])
demo.launch(share=True, server_port=13570, server_name='0.0.0.0', auth=(username, password))
demo.launch(share=share, server_port=13570, server_name='0.0.0.0', auth=(username, password))

0 comments on commit 125e499

Please sign in to comment.