Skip to content

Commit

Permalink
updates axo, vllm (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesfrye committed Jul 6, 2024
1 parent 7923eb7 commit be0306c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@
MINUTES = 60 # seconds
HOURS = 60 * MINUTES

# Axolotl image hash corresponding to main-20240522-py3.11-cu121-2.2.2
# Axolotl image hash corresponding to main-20240705-py3.11-cu121-2.3.0
AXOLOTL_REGISTRY_SHA = (
"8ec2116dd36ecb9fb23702278ac612f27c1d4309eca86ad0afd3a3fe4a80ad5b"
"9578c47333bdcc9ad7318e54506b9adaf283161092ae780353d506f7a656590a"
)

ALLOW_WANDB = os.environ.get("ALLOW_WANDB", "false").lower() == "true"

axolotl_image = (
modal.Image.from_registry(f"winglian/axolotl@sha256:{AXOLOTL_REGISTRY_SHA}")
.pip_install(
"huggingface_hub==0.20.3",
"huggingface_hub==0.23.2",
"hf-transfer==0.1.5",
"wandb==0.16.3",
"fastapi==0.110.0",
Expand All @@ -38,8 +38,8 @@
vllm_image = modal.Image.from_registry(
"nvidia/cuda:12.1.0-base-ubuntu22.04", add_python="3.10"
).pip_install(
"vllm==0.2.6",
"torch==2.1.2",
"vllm==0.5.0post1",
"torch==2.3.0",
"numpy<2", # To avoid vLLM ecosystem compatibility issues
)

Expand Down
5 changes: 5 additions & 0 deletions src/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def init(self):
model=model_path,
gpu_memory_utilization=0.95,
tensor_parallel_size=N_INFERENCE_GPUS,
disable_custom_all_reduce=True, # brittle as of v0.5.0
)
self.engine = AsyncLLMEngine.from_engine_args(engine_args)

Expand Down Expand Up @@ -123,11 +124,15 @@ async def web(self, input: str):

@modal.exit()
def stop_engine(self):
    """Shut the vLLM engine down when the Modal container exits.

    Registered via ``@modal.exit()`` so it runs as a container-teardown
    hook. NOTE(review): assumes ``self.engine`` was created in ``init``
    via ``AsyncLLMEngine.from_engine_args`` — confirm against the class.
    """
    print("stopping")
    if N_INFERENCE_GPUS > 1:
        # Presumably multi-GPU tensor parallelism runs vLLM workers on
        # Ray; shut the Ray runtime down explicitly so worker processes
        # don't outlive the container — TODO confirm Ray is initialized
        # in this configuration.
        import ray

        ray.shutdown()

    # access private attribute to ensure graceful termination:
    # vLLM (as of v0.5.0) exposes no public shutdown API, so cancel the
    # engine's private background-loop task directly. HACK: brittle —
    # may break on vLLM upgrades.
    self.engine._background_loop_unshielded.cancel()


@app.local_entrypoint()
def inference_main(run_name: str = "", prompt: str = ""):
Expand Down

0 comments on commit be0306c

Please sign in to comment.