
Fix ONNX exports for Optimum compatible models #31311

Merged: 13 commits into huggingface:main on Jun 27, 2024

Conversation

@merveenoyan (Contributor) commented Jun 7, 2024

@amyeroberts as discussed and also pinging @xenova for review :') (who also fixed DPT)

I prioritized the Optimum-compatible ones because I'm launching a project where there are Optimum examples for vision models. I will open a separate PR for the models that aren't compatible with Optimum. The rest of the Optimum-compatible models export without a problem.

Comment on lines 229 to 231
def safe_int(x):
    return x.to(torch.int64) if torch.jit.is_tracing() else int(x)
old_grid_size = safe_int(posemb_grid.size(0) ** 0.5)
Contributor:

It might be necessary to extract this cast function to a utility file, and then re-use it across different models. I have a feeling that other instances, like:

[screenshot omitted]

will break execution, since during inference they may require plain Python types.
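
(For illustration only, not from the PR: the reason a plain int() is problematic on the traced path is that the value gets baked into the graph as a constant. A minimal sketch:)

    import torch

    def make_square(x):
        side = int(x.numel() ** 0.5)  # int() is evaluated at trace time and frozen
        return x.reshape(side, side)

    traced = torch.jit.trace(make_square, torch.arange(16.0))
    traced(torch.arange(16.0))  # fine: reshapes to 4x4
    traced(torch.arange(25.0))  # fails: the graph still reshapes to [4, 4]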

Comment on lines 440 to 441
new_height = (torch.ceil(orig_height / patch_height) * patch_height).to(torch.int64)
new_width = (torch.ceil(orig_width / patch_width) * patch_width).to(torch.int64)
Contributor:

Same comment as above - doesn't interpolate require (int, int) when not tracing?

Contributor Author (@merveenoyan):

I'll check tracing, thanks for the heads up
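
(For context: the tracing-aware guard under discussion would look roughly like this; a sketch, assuming orig_height and orig_width are tensors while tracing and plain Python ints otherwise.)

    import math
    import torch

    if torch.jit.is_tracing():
        # keep everything as tensors so the exported graph stays dynamic
        new_height = (torch.ceil(orig_height / patch_height) * patch_height).to(torch.int64)
        new_width = (torch.ceil(orig_width / patch_width) * patch_width).to(torch.int64)
    else:
        # eager inference: downstream ops such as interpolate expect plain Python ints
        new_height = int(math.ceil(orig_height / patch_height) * patch_height)
        new_width = int(math.ceil(orig_width / patch_width) * patch_width)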

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@merveenoyan merveenoyan requested a review from xenova June 7, 2024 13:27
@amyeroberts (Collaborator) left a comment:

Very nice! Thanks for fixing this for all these models ❤️

Just a few small comments

src/transformers/utils/generic.py (outdated)
@@ -750,3 +750,15 @@ def infer_framework(model_class):
             return "flax"
     else:
         raise TypeError(f"Could not infer framework from class {model_class}.")
+
+
+def safe_int(x):
Collaborator:

Docstrings would be helpful here e.g. for inspecting in IDEs: what does it mean for an int to be safe?

Contributor:

Indeed, a better name is probably a good idea 😅 I called it safe_int to mean "safely cast some value (which could be a Python number or tensor) to an integer in a way that respects tracing".

@merveenoyan (Contributor, Author) commented Jun 11, 2024:

I'll swap with torch_int and torch_float
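
(A sketch of what those helpers could look like, with the docstrings requested above; not necessarily the exact code that was merged:)

    import torch

    def torch_int(x):
        """
        Casts x to a torch int64 tensor while tracing (keeping the graph dynamic),
        and to a plain Python int otherwise.
        """
        return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)

    def torch_float(x):
        """
        Casts x to a torch float32 tensor while tracing, and to a plain Python float otherwise.
        """
        return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else float(x)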

new_width = int(math.ceil(orig_width / patch_width) * patch_width)
new_height = (
    safe_float(torch.ceil(orig_height / patch_height) * patch_height)
    if torch.jit.is_tracing()
Collaborator:

Do we need the conditional here? This is already handled in the safe_float and safe_int functions

Contributor Author (@merveenoyan):

I think it's required for torch.ceil no?

@amyeroberts (Collaborator) commented Jun 7, 2024:

tbh, I don't know; is there a reason we couldn't use torch.ceil directly?

Contributor Author (@merveenoyan):

If I'm passing an int or float, torch.ceil will be called first and it will fail, because torch.ceil can only be called with tensors AFAIK.
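
(A quick check of that constraint:)

    import torch

    torch.ceil(torch.tensor(2.3))  # tensor(3.)
    torch.ceil(2.3)                # raises TypeError: ceil() expects a Tensor, not a Python float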

Collaborator:

Makes sense!

Collaborator:

The only other Q here then is: why do we use a float when tracing and an int otherwise?

Contributor Author (@merveenoyan):

sorry I think I was mistaken with that one, you're right, I fixed it :)

@merveenoyan (Contributor, Author):

@amyeroberts the failing tests seem unrelated to this PR. I can't re-run them; can you?

@amyeroberts (Collaborator):

@merveenoyan si si - done!

@amyeroberts (Collaborator) left a comment:

Thanks for fixing all of these!

Just for my own understanding: is there any reason not to use the torch-compatible float/int when we're not tracing?

@merveenoyan (Contributor, Author):

@amyeroberts to my understanding, torch's ONNX export internally calls trace, so to capture the graph properly everything needs to be a torch tensor: one needs to check for tracing and convert values to torch tensors (using torch ops internally) when they aren't already.
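
(A minimal sketch of that point; the module and output file name here are hypothetical. torch.onnx.export runs the model through the tracer with the example input, so any int()/float() taken from a tensor inside forward() would be frozen at that example's value:)

    import torch

    model = torch.nn.Upsample(scale_factor=2, mode="nearest")
    torch.onnx.export(
        model,
        torch.randn(1, 3, 32, 32),  # example input used for the internal trace
        "upsample.onnx",            # hypothetical output path
        input_names=["pixel_values"],
        dynamic_axes={"pixel_values": {2: "height", 3: "width"}},
    )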

@merveenoyan (Contributor, Author):

@amyeroberts can you merge if you think it's ok?

@amyeroberts (Collaborator):

> @amyeroberts to my understanding, torch's ONNX export internally calls trace, so to capture the graph properly everything needs to be a torch tensor: one needs to check for tracing and convert values to torch tensors (using torch ops internally) when they aren't already.

Right, I see why we need to do it for the ONNX export. But for day-to-day use, could we just use torch primitives instead of a Python int or float? I.e., do we need to maintain this if/else structure, or can we switch everything to torch land regardless of whether we're tracing?

@merveenoyan (Contributor, Author):

@amyeroberts I guess if it's just torch modelling code then yes. Would you like me to swap everything?

@merveenoyan (Contributor, Author):

also asking the same question to @xenova

@amyeroberts (Collaborator):

> Would you like me to swap everything?

@merveenoyan Yes please! This will be cleaner and easier to follow in the code :)

@xenova (Contributor) commented Jun 21, 2024:

> also asking the same question to @xenova

I agree with @amyeroberts: if there is a way to "do everything in torch land", that's the best solution! However, there are cases where I'm not entirely sure how to do this. For example, with torch.nn.functional.interpolate:

  • during inference, the sizes/scales CANNOT be tensors; they must be Python ints/floats.
  • during tracing, the sizes/scales SHOULD be tensors; dynamic shapes usually break if they are Python ints/floats, since the conversion drops them from the trace.

See here for example code (DINOv2 backbone):

    if torch.jit.is_tracing():
        sqrt_N = N**0.5
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, sqrt_N.to(torch.int64), sqrt_N.to(torch.int64), dim).permute(0, 3, 1, 2),
            size=(w0, h0),
            mode="bicubic",
            antialias=self.interpolate_antialias,
        )
    else:
        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
            scale_factor=(sx, sy),
            mode="bicubic",
            antialias=self.interpolate_antialias,
        )

Very ugly... I know :/

@merveenoyan (Contributor, Author):

@xenova sounds good, very glad to work with you! tbh I didn't know that it would be required in inference.
@amyeroberts should we merge?

@amyeroberts (Collaborator):

@merveenoyan My understanding from above was that the PR would be updated to remove all the if/else structures wherever possible (though, as @xenova points out, that isn't possible everywhere, unfortunately).

@merveenoyan (Contributor, Author):

@amyeroberts from what I understood, we should still keep the if/else so as not to break inference (I'm also wary of possible edge cases), so I'd rather keep them. What I can do is test all of them to see whether they break when everything is a tensor, and remove the guard wherever a Python type isn't required.

@amyeroberts (Collaborator):

@merveenoyan OK. Let's just merge then and we can follow up in future PRs 👍

@amyeroberts amyeroberts merged commit c9f191a into huggingface:main Jun 27, 2024
23 checks passed