
Adds in ability to cache low res logits from prompts #582

Merged
merged 30 commits into main from sam2-id-free-caching on Aug 22, 2024

Conversation

probicheaux
Collaborator

Description

Performs a breadth-first search to find the most similar prior prompt when loading cached logits.
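For context, a minimal sketch of that lookup, assuming a cache keyed by (image_id, serialized prompt); the names and layout here are illustrative, not the PR's exact code:

from collections import deque
from typing import Dict, Optional, Tuple

import numpy as np

# Hypothetical cache layout: (image_id, serialized prompt) -> low-res logits.
LogitsCache = Dict[Tuple[str, str], np.ndarray]


def find_most_similar_cached_logits(
    image_id: str,
    points: Tuple[Tuple[float, float], ...],
    cache: LogitsCache,
) -> Optional[np.ndarray]:
    """BFS over sub-prompts: drop one point at a time and return the first
    (i.e. largest, most similar) cached hit."""
    seen = {points}
    queue = deque([points])
    while queue:
        candidate = queue.popleft()
        hit = cache.get((image_id, repr(candidate)))
        if hit is not None:
            return hit
        # Enqueue every sub-prompt with exactly one point removed.
        for i in range(len(candidate)):
            smaller = candidate[:i] + candidate[i + 1:]
            if smaller and smaller not in seen:
                seen.add(smaller)
                queue.append(smaller)
    return None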

Performance appears to be better when using cached logits. Compare sam_negative_prompted_small_no_mask (no cached logits) vs. sam_negative_prompted (cached logits).

Additionally, adds logic to pad the input points to the SAM model so that two different prompts with differing numbers of points can be used.
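The padding idea, roughly (a sketch, not the PR's exact helper; the -1 "ignore" label follows the SAM-style prompt-encoder convention):

import numpy as np


def pad_points(points: np.ndarray, labels: np.ndarray, target_len: int):
    """Pad (N, 2) point coords and (N,) labels up to target_len so prompts with
    different point counts can share one model call. Padding points get label -1,
    which a SAM-style prompt encoder treats as "not a point"."""
    pad = target_len - len(points)
    if pad <= 0:
        return points, labels
    padded_points = np.concatenate([points, np.zeros((pad, 2), dtype=points.dtype)])
    padded_labels = np.concatenate([labels, -np.ones(pad, dtype=labels.dtype)])
    return padded_points, padded_labels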

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

How has this change been tested? Please provide a test case or an example of how you tested the change.

Locally, integration tests

Any specific deployment considerations

For example, documentation changes, usability, usage/costs, secrets, etc.

Docs

  • Docs updated? What were the changes:

@probicheaux
Collaborator Author

Not ready for review yet

@probicheaux probicheaux marked this pull request as draft August 16, 2024 01:31
@probicheaux probicheaux marked this pull request as ready for review August 19, 2024 17:53
@probicheaux
Collaborator Author

Ready for review @PawelPeczek-Roboflow

@@ -283,6 +283,8 @@

# Maximum embedding cache size for SAM, default is 10
SAM_MAX_EMBEDDING_CACHE_SIZE = int(os.getenv("SAM_MAX_EMBEDDING_CACHE_SIZE", 10))
# The sam2 low_res_masks are the biggest memory usage, and 1000 of them take 256*256*4*1000/1024/1024 MB = 250MB
SAM2_MAX_CACHE_SIZE = int(os.getenv("SAM_MAX_EMBEDDING_CACHE_SIZE", 1000))
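For reference, the arithmetic in that comment checks out, assuming float32 256x256 low-res logits:

# 256x256 float32 logits per mask, 1000 masks in the cache.
mask_bytes = 256 * 256 * 4
print(mask_bytes * 1000 / 1024 / 1024)  # 250.0 MB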
Collaborator

I propose to decouple SAM from SAM2 in terms of the env variables configuring the model. Also, let's list the config variable on this page.

@PawelPeczek-Roboflow
Collaborator

There is a fundamental problem I detected: the algorithm used in find_prior_prompt_in_cache(...) is computationally intractable when there is no hit in the cache.

Code to reproduce:

import time

from inference.core.entities.requests.sam2 import (
    Sam2EmbeddingRequest,
    Sam2InferenceRequest,
    Sam2Prompt,
    Sam2PromptSet,
    Box, Point,
    Sam2SegmentationRequest,
)
from inference.models.sam2.segment_anything2 import find_prior_prompt_in_cache

# Three identical prompts, each with a box and five points.
initial_prompt_set = Sam2PromptSet(
    prompts=[Sam2Prompt(
        box=Box(x=10, y=10, height=10, width=10),
        points=[Point(x=10, y=10, positive=True)] * 5
    )] * 3
)


start = time.time()
find_prior_prompt_in_cache(  # empty cache guarantees a miss
    initial_prompt_set=initial_prompt_set,
    image_id="some",
    cache={},
)
print(f"Duration: {(time.time() - start) * 1000}ms")

Basically, you find the search stack growing exponentially.
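A rough way to see the scale of the problem (this assumes the search strips one point at a time without de-duplicating sub-prompts; the real code may differ):

import math

# With 3 prompts x 5 points = 15 points, a miss can visit on the order of
# n! removal orderings before giving up.
n = 3 * 5
print(f"{math.factorial(n):,}")  # 1,307,674,368,000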

@probicheaux
Copy link
Collaborator Author

@PawelPeczek-Roboflow ready for re-review

The problem we're trying to solve is this: suppose you have a prompt with n points. You get a mask, and you want to negative prompt based on that mask. You have to feed that mask back in when adding the (n+1)-st point. So when we receive the request with n+1 points, we need to load the mask from the n-point prompt.

The reason we want to cache these masks is that they take ~300 KB each, and we don't want to incur the latency of serializing/deserializing them into np arrays, plus network latency. This matters because smart poly is trying to run in real time for image previews.
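A minimal sketch of that flow, using a hypothetical predictor whose predict call mirrors the SAM2 image-predictor interface; the cache key and helper name are illustrative, not the PR's actual code:

from typing import Dict, Tuple

import numpy as np

# Hypothetical in-memory cache: (image_id, point tuple) -> low-res logits (~300 KB each).
logits_cache: Dict[Tuple[str, tuple], np.ndarray] = {}


def segment_with_feedback(image_id, point_coords, point_labels, predictor):
    # Look up the logits produced by the n-point prompt when the (n+1)-point
    # request for the same image arrives.
    prior = logits_cache.get((image_id, tuple(map(tuple, point_coords[:-1]))))
    masks, scores, low_res_logits = predictor.predict(
        point_coords=np.asarray(point_coords),
        point_labels=np.asarray(point_labels),
        mask_input=prior,          # None on the first request for this image
        multimask_output=False,    # single mask so its logits can be fed straight back
    )
    # Store this prompt's logits so the next refinement can pick them up.
    logits_cache[(image_id, tuple(map(tuple, point_coords)))] = low_res_logits
    return masks, scores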

@PawelPeczek-Roboflow
Collaborator

LGTM, three minor changes:

  • Please add a global env flag disabling the functionality at the server level, and add descriptions to the changed request fields noting that the functionality may be disabled by server config (see the sketch after this list):
    save_logits_to_cache: bool = Field(default=False)
    load_logits_from_cache: bool = Field(default=False)
  • The typing of self.low_res_logits_cache: LogitsCacheType = {} is wrong - this is Dict[Tuple[str, str], LogitsCacheType]

  • The return type of find_prior_prompt_in_cache(...) is probably Optional[np.ndarray]
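One possible shape for that flag, as a sketch; the env variable name, model name, and helper are assumptions, not the PR's actual code:

import os
from typing import Tuple

from pydantic import BaseModel, Field

# Hypothetical server-level kill switch for the logits cache.
DISABLE_SAM2_LOGITS_CACHE = os.getenv("DISABLE_SAM2_LOGITS_CACHE", "False").lower() == "true"


class Sam2SegmentationFields(BaseModel):
    # Both flags may be ignored when the cache is disabled server-side.
    save_logits_to_cache: bool = Field(
        default=False,
        description="Cache low-res logits for this prompt; may be disabled by server config.",
    )
    load_logits_from_cache: bool = Field(
        default=False,
        description="Seed the model with cached logits from a prior prompt; may be disabled by server config.",
    )


def effective_cache_flags(request: Sam2SegmentationFields) -> Tuple[bool, bool]:
    if DISABLE_SAM2_LOGITS_CACHE:
        return False, False
    return request.save_logits_to_cache, request.load_logits_from_cache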

@tonylampada
Contributor

@PawelPeczek-Roboflow thanks for helping me implement the fixes you suggested.
I just pushed them.

@PawelPeczek-Roboflow PawelPeczek-Roboflow dismissed stale reviews from grzegorz-roboflow and themself via 93241f3 August 22, 2024 16:51
@PawelPeczek-Roboflow PawelPeczek-Roboflow merged commit d3be171 into main Aug 22, 2024
58 checks passed
@PawelPeczek-Roboflow PawelPeczek-Roboflow deleted the sam2-id-free-caching branch August 22, 2024 17:15