Adds in ability to cache low res logits from prompts #582
Conversation
Not ready for review yet
Ready for review @PawelPeczek-Roboflow
inference/core/env.py (outdated)

```diff
@@ -283,6 +283,8 @@
 # Maximum embedding cache size for SAM, default is 10
 SAM_MAX_EMBEDDING_CACHE_SIZE = int(os.getenv("SAM_MAX_EMBEDDING_CACHE_SIZE", 10))
+# The sam2 low_res_masks are the biggest memory user; 1000 of them take 256*256*4*1000/1024/1024 bytes ≈ 250 MB
+SAM2_MAX_CACHE_SIZE = int(os.getenv("SAM_MAX_EMBEDDING_CACHE_SIZE", 1000))
```
I propose decoupling SAM from SAM2 in terms of the env variables configuring the model (a rough sketch is below); plus, let's list the config variable on this page.
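A minimal sketch of that decoupling, where `SAM2_MAX_LOGITS_CACHE_SIZE` is an assumed name for illustration, not necessarily what was merged:

```python
import os

# SAM keeps its existing knob:
SAM_MAX_EMBEDDING_CACHE_SIZE = int(os.getenv("SAM_MAX_EMBEDDING_CACHE_SIZE", 10))
# SAM2 gets its own variable instead of silently reusing the SAM one:
SAM2_MAX_LOGITS_CACHE_SIZE = int(os.getenv("SAM2_MAX_LOGITS_CACHE_SIZE", 1000))
```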
There is a fundamental problem I detected: the algorithm used in `find_prior_prompt_in_cache(...)` is computationally intractable when there is no hit in the cache.
Code to reproduce:

```python
import time

from inference.core.entities.requests.sam2 import (
    Box,
    Point,
    Sam2Prompt,
    Sam2PromptSet,
)
from inference.models.sam2.segment_anything2 import find_prior_prompt_in_cache

initial_prompt_set = Sam2PromptSet(
    prompts=[
        Sam2Prompt(
            box=Box(x=10, y=10, height=10, width=10),
            points=[Point(x=10, y=10, positive=True)] * 5,
        )
    ]
    * 3
)

start = time.time()
find_prior_prompt_in_cache(
    initial_prompt_set=initial_prompt_set,
    image_id="some",
    cache={},
)
print(f"Duration: {(time.time() - start) * 1000}ms")
```
Basically, you find the stack growing exponentially: on a cache miss the search keeps expanding smaller and smaller variants of the prompt set, so the number of candidates blows up combinatorially before the function can conclude there is no hit.
…into sam2-id-free-caching
@PawelPeczek-Roboflow ready for re-review. The problem we're trying to solve is this: suppose you have a prompt with n points. You get a mask back, and you want to negative-prompt based on that mask, so you have to feed the mask back in when adding the (n+1)th point. When we receive the request with n+1 points, we need to load the mask produced by the n-point prompt. The reason we want to cache these masks is that each one takes ~300 KB, and we don't want to incur the latency of serializing/deserializing them into np arrays, nor the network latency of shipping them around. This matters because Smart Polygon is trying to run in real time for image previews.
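Here is a minimal sketch of that round trip, not the PR's actual code: the cache layout, the key scheme, and the `predictor.predict(...)` interface are assumptions for illustration.

```python
from typing import Dict, Optional, Tuple

import numpy as np

# (image_id, prompt_key) -> low-res logits; each entry is 256*256*4 bytes ≈ 256 KB
logits_cache: Dict[Tuple[str, str], np.ndarray] = {}


def segment_with_cache(predictor, image_id, coords, labels, prompt_key, prev_prompt_key=None):
    # On the (n+1)-point request, reuse the low-res mask cached for the
    # n-point prompt instead of round-tripping ~300 KB arrays over the network.
    mask_input: Optional[np.ndarray] = logits_cache.get((image_id, prev_prompt_key))
    masks, scores, low_res_logits = predictor.predict(
        point_coords=coords,
        point_labels=labels,
        mask_input=mask_input,
    )
    # Save the fresh logits so the (n+2)-point request can pick them up.
    logits_cache[(image_id, prompt_key)] = low_res_logits
    return masks
```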
LGTM, three minor changes:
- please add a global env flag disabling the functionality at the server level (see the sketch after this list), and add descriptions to the changed request fields noting that the functionality may be disabled by server config:
  `save_logits_to_cache: bool = Field(default=False)`
  `load_logits_from_cache: bool = Field(default=False)`
- the typing of `self.low_res_logits_cache: LogitsCacheType = {}` is wrong; this is `Dict[Tuple[str, str], LogitsCacheType]`
- the return type of `find_prior_prompt_in_cache(...)` is probably `Optional[np.ndarray]`
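A minimal sketch of such a server-level kill switch, following the repo's pattern of reading env variables in `inference/core/env.py`; the name `DISABLE_SAM2_LOGITS_CACHE` is an assumption, not necessarily what was merged:

```python
import os

# Illustrative server-level flag; the variable name is an assumption.
# When true, the server ignores save_logits_to_cache / load_logits_from_cache
# on incoming requests.
DISABLE_SAM2_LOGITS_CACHE = os.getenv("DISABLE_SAM2_LOGITS_CACHE", "False").lower() in ("true", "1")
```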
@PawelPeczek-Roboflow thanks for helping me implement the fixes you suggested.
93241f3
Description
Performs a breadth-first search to find the most similar prior prompt for loading cached logits; a rough sketch of the idea follows.
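A hedged sketch of what such a search can look like, assuming a cache keyed by `(image_id, prompt_hash)`; the names below are illustrative, not the PR's actual functions:

```python
from collections import deque
from typing import Callable, Dict, Optional, Tuple

import numpy as np


def find_most_similar_cached_logits(
    points: Tuple,
    image_id: str,
    cache: Dict[Tuple[str, str], np.ndarray],
    hash_fn: Callable[[Tuple], str],
) -> Optional[np.ndarray]:
    # Breadth-first: try the full prompt, then every prompt with one point
    # dropped, then two, and so on; the first hit is the most similar prompt.
    start = tuple(points)
    seen = {start}
    queue = deque([start])
    while queue:
        candidate = queue.popleft()
        hit = cache.get((image_id, hash_fn(candidate)))
        if hit is not None:
            return hit
        for i in range(len(candidate)):
            smaller = candidate[:i] + candidate[i + 1:]
            if smaller and smaller not in seen:
                seen.add(smaller)
                queue.append(smaller)
    return None
```

The `seen` set is what keeps a cache miss tractable: without deduplication the number of visited candidates grows exponentially, which is exactly the blow-up flagged in the review above.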
Performance appears to be better when using cached logits. [Benchmark screenshots: timings without cached logits vs. with cached logits.]
Additionally, adds logic to pad the input points to the SAM model so that two different prompts with differing numbers of points can be used.
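A minimal sketch of that padding, assuming the SAM convention that padding points are given coordinates (0, 0) and label -1 so the prompt encoder ignores them; `pad_points` is an illustrative helper, not the PR's function:

```python
import numpy as np


def pad_points(coords: np.ndarray, labels: np.ndarray, target_len: int):
    """Pad (n, 2) point coords and (n,) labels up to target_len."""
    pad = target_len - coords.shape[0]
    if pad <= 0:
        return coords, labels
    coords = np.concatenate([coords, np.zeros((pad, 2), dtype=coords.dtype)])
    labels = np.concatenate([labels, np.full(pad, -1, dtype=labels.dtype)])
    return coords, labels
```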
Type of change
How has this change been tested? Please provide a test case or an example of how you tested the change.
Locally, integration tests
Any specific deployment considerations
For example, documentation changes, usability, usage/costs, secrets, etc.
Docs