Fix deepspeed prefix-lm #107

Merged: 11 commits merged into main on Oct 10, 2021
Conversation

thomasw21 (Member) commented Sep 17, 2021

The deepspeed implementation prevents one from passing down an attention mask. We need this feature for prefix-lm, which has an attention mask that is batch dependent.

Maybe worth adding this commit to #10
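For context, a minimal sketch (not code from this PR; the helper name is made up) of why the prefix-lm mask is batch dependent: each sample attends bidirectionally inside its own prefix and causally after it, so the mask depends on every sample's prefix length and cannot be a single precomputed mask shared by all batches.

```python
import torch

def build_prefix_lm_mask(prefix_lengths, seq_length):
    """Illustrative only: build one attention mask per sample."""
    batch_size = prefix_lengths.size(0)
    # Shared causal (lower-triangular) part.
    causal = torch.tril(torch.ones(seq_length, seq_length, dtype=torch.bool))
    mask = causal.unsqueeze(0).expand(batch_size, -1, -1).clone()
    for b, prefix_len in enumerate(prefix_lengths.tolist()):
        # Tokens inside the prefix may attend to every other prefix token.
        mask[b, :prefix_len, :prefix_len] = True
    return mask  # [batch_size, seq_length, seq_length], True = may attend

# Two samples with different prefix lengths yield two different masks,
# which is why a single mask stored on args is not enough here.
masks = build_prefix_lm_mask(torch.tensor([2, 4]), seq_length=6)
```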

@@ -210,6 +218,10 @@ def _to_float16(inputs):
self_attn_mask_type=AttnMaskType.prefix if prefix_lm else AttnMaskType.causal))


        if not hasattr(args, 'attn_mask'):
            # We drop attention mask from the pipeline
            self.specs.append(lambda x: x[0])
thomasw21 (Member, Author):

Unfortunately we drop the attention mask here ...

@@ -290,7 +290,6 @@ def forward(self, inputs, **kwargs):
        if hasattr(self._args, 'attn_mask'):
            return embeddings
        else:
            assert False
thomasw21 (Member, Author):

We remove this in order to allow this case.

Comment on lines +68 to +74
    model = GPTModelPipe(
        num_tokentypes=0,
        parallel_output=True
    )
    # This is a hack to give us a reference to get_batch_pipe from within training.py
    # We need to call model.set_batch_fn after deepspeed.initialize
    model._megatron_batch_fn = get_batch_pipe
thomasw21 (Member, Author):

We move this part of the code so that it runs after the attention mask has been set in args. This allows us to use args while building the model, typically to distinguish gpt vs prefix-lm (see the sketch below).
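A rough sketch of that ordering (the names below are illustrative stand-ins, not the repo's actual model_provider): the shared causal mask is stashed on args for plain gpt before the model is built, and deliberately not stashed for prefix-lm, so that the model builder can branch on hasattr(args, 'attn_mask').

```python
import argparse
import torch

class FakeGPTModelPipe:
    """Stand-in for GPTModelPipe, only here to make the ordering runnable."""
    def __init__(self, num_tokentypes=0, parallel_output=True):
        # The real GPTModelPipe inspects args (e.g. hasattr(args, 'attn_mask'))
        # while wiring its layer specs.
        pass

def model_provider_sketch(args, get_batch_pipe):
    if not args.prefix_lm:
        # Plain gpt: the causal mask is the same for every batch, so it can be
        # precomputed once, stored on args, and kept out of the pipeline.
        args.attn_mask = torch.tril(
            torch.ones(args.seq_length, args.seq_length, dtype=torch.bool))
    # For prefix-lm we do NOT set args.attn_mask: the mask is batch dependent
    # and must travel through the pipeline stages instead.

    # Build the model only after args does (or does not) carry attn_mask.
    model = FakeGPTModelPipe(num_tokentypes=0, parallel_output=True)
    # Same hack as the diff above: stash the batch fn so training.py can call
    # model.set_batch_fn after deepspeed.initialize.
    model._megatron_batch_fn = get_batch_pipe
    return model

args = argparse.Namespace(prefix_lm=True, seq_length=8)
model = model_provider_sketch(args, get_batch_pipe=lambda batch: batch)
```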

@thomasw21 thomasw21 changed the title Fix deepspeed prefix Fix deepspeed prefix-lm Sep 17, 2021
@thomasw21 thomasw21 changed the title Fix deepspeed prefix-lm WIP: Fix deepspeed prefix-lm Sep 17, 2021
@thomasw21 thomasw21 marked this pull request as draft September 17, 2021 13:31
@thomasw21 thomasw21 removed the request for review from stas00 September 17, 2021 13:31
stas00 (Member) commented Sep 17, 2021

@thomasw21, is there a way to fix it directly in deepspeed? Or do you feel the fix here is good enough?

thomasw21 (Member, Author):

I think we're going to have to fix something in deepspeed somewhere. I'm still unclear on the errors I'm getting ...

https://huggingface.slack.com/archives/C01NHER1JLS/p1631861191377400

@@ -33,6 +33,9 @@
import os
import subprocess

class GPT2ModelPipe(GPTModelPipe):
thomasw21 (Member, Author) commented Sep 21, 2021:

That's a terrible hack.

deepspeed doesn't handle passing bool tensors well; it hacks its way around this for specific modules such as GPT2ModelPipe.
Example: https://github.com/microsoft/DeepSpeed/blob/big-science/deepspeed/runtime/pipe/engine.py#L857

My own fork of deepspeed is necessary to run the model:
https://github.com/thomasw21/DeepSpeed/tree/big-science-fix-passing-multiple-tensors
The fork basically allows passing a tuple of tensors instead of a single tensor between stages.

Some more long-term solutions I can think of:

  • convert the mask from BoolTensor to Int16Tensor? Hopefully we wouldn't need to handle such behaviours in deepspeed; we would just have to manually convert it back to bool whenever we use it (see the sketch after this list)
  • find a better way to handle BoolTensor directly in DS? Typically, maybe we can pass the dtype down as tensor meta in order to generate a buffer with the correct dtype.
  • I think the original reason for this issue is that torch.distributed doesn't handle bool tensors very well, but that seems to have been fixed in pytorch/pytorch#24137 ([distributed] NCCL Backend doesn't support torch.bool data type). Since our setup uses a newer version, maybe we can remove all the code linked to bool tensors?
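A minimal sketch of the first option, done at the model level rather than inside deepspeed (the helper names are made up, and whether int16 is supported end-to-end by the stage-to-stage communication would still need checking):

```python
import torch

def encode_mask_for_pipeline(mask: torch.Tensor) -> torch.Tensor:
    # BoolTensor -> integer tensor before it is handed to the pipeline.
    return mask.to(torch.int16)

def decode_mask_from_pipeline(encoded: torch.Tensor) -> torch.Tensor:
    # Back to bool on the receiving stage, right before the attention layer uses it.
    return encoded.to(torch.bool)

mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
roundtrip = decode_mask_from_pipeline(encode_mask_for_pipeline(mask))
assert torch.equal(roundtrip, mask)
```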

cc @stas00 @ShadenSmith (I might be completely off right here in my understanding, feel free to correct me)

stas00 (Member) commented Sep 21, 2021:

Could you explain how your hack works, @thomasw21?


Yes, this is one of the most limiting aspects of most pipe implementations I have seen - not just in DeepSpeed. The new pytorch API is much more flexible with respect to passing different types of inputs.

When I was trying to implement PP support in transformers I created all kinds of encoding/decoding wrappers to cope with this. It's not pretty.

I think the original reason for this issue is that torch.distributed doesn't handle bool tensors very well, but that seems to have been fixed in pytorch/pytorch#24137. Since our setup uses a newer version, maybe we can remove all the code linked to bool tensors?

I agree - this should have been fixed in pt-1.7 according to the link you shared.

Though the problem is that DS has to support older versions, so it's tricky.

@ShadenSmith - do you have any good suggestions?

thomasw21 (Member, Author) commented Sep 21, 2021:

Argh, so things didn't work; it seems to deadlock now instead of throwing (or it is very slow) ... I'm not sure which, or why, right now; I will investigate some more.

Concerning the hack: it assumes that the last element of the input is the attention mask. There's a built-in hack in deepspeed right now that implements this; in order to trigger it, you have to rename the model class to GPT2ModelPipe.
https://github.com/microsoft/DeepSpeed/blob/big-science/deepspeed/runtime/pipe/engine.py#L854-L860
Basically it converts the tensor to float16, passes it through torch.distributed, and then converts it back, in order to allow the engine to pass a bool tensor from one stage to another.
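Roughly what that engine hack boils down to, as an illustrative sketch (the function names and the toy send/recv plumbing below are invented; see the linked engine.py for the real code):

```python
import torch

def send_stage_outputs(outputs, send_fn, is_gpt2_model_pipe):
    # outputs: tuple of activations, with the bool attention mask last.
    if is_gpt2_model_pipe:
        *activations, attn_mask = outputs
        # Bool tensors can't go through the p2p buffers, so the mask is
        # shipped as float16 ...
        outputs = (*activations, attn_mask.to(torch.float16))
    for tensor in outputs:
        send_fn(tensor)

def recv_stage_inputs(tensors, is_gpt2_model_pipe):
    if is_gpt2_model_pipe:
        *activations, attn_mask = tensors
        # ... and cast back to bool on the receiving stage.
        tensors = (*activations, attn_mask.to(torch.bool))
    return tensors

# Toy "communication": append to a list instead of torch.distributed send/recv.
wire = []
send_stage_outputs((torch.randn(2, 3), torch.ones(2, 3, dtype=torch.bool)),
                   wire.append, is_gpt2_model_pipe=True)
hidden, mask = recv_stage_inputs(tuple(wire), is_gpt2_model_pipe=True)
assert mask.dtype == torch.bool
```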

Though now that I think about it, if naming the model GPT2ModelPipe didn't work, I think there's an issue with how DS handles bool tensors (on top of pytorch): the initial buffer doesn't have the correct dtype. Maybe we can pass the dtype when we send the meta? However, that's a bit outside my expertise.

Member:

Ah, I see, I missed that it checks for a specific class name to activate different functionality. Thank you for explaining, Thomas.

I hope @ShadenSmith will have a moment to recommend a solution here, since I think he is the author of that code.

Contributor:

Thanks for the ping, and sorry for the trouble with these hacks.

I'm inclined to remove the hacks and just expect folks to either use torch 1.7+, or convert their activations from unsupported types at the model level.

We can support much more flexible message shapes/types with more robust meta-communication. There are some p2p routines that serialize the message with pickle() and then just send that over to exchange the meta-information. Using pickle has a few tradeoffs, including an expensive GPU -> CPU -> GPU transfer for the first activation exchange, and the typical security risks associated with pickling.
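To make that concrete, a hedged sketch of what such a pickle-based meta exchange could look like (the helper names are invented and the real DeepSpeed routines differ): the sender pickles per-tensor metadata into a byte tensor, and the receiver rebuilds correctly typed buffers from it, including bool ones.

```python
import pickle
import torch

def pack_meta(tensors):
    # Pickle (shape, dtype) pairs and wrap them in a uint8 tensor so the
    # metadata itself could be sent along the regular tensor p2p path.
    meta = [(tuple(t.shape), t.dtype) for t in tensors]
    return torch.tensor(list(pickle.dumps(meta)), dtype=torch.uint8)

def unpack_meta(payload):
    meta = pickle.loads(bytes(payload.tolist()))
    # Allocate receive buffers with the right shape *and* dtype, so even a
    # bool attention mask gets a correctly typed buffer.
    return [torch.empty(shape, dtype=dtype) for shape, dtype in meta]

tensors = (torch.randn(2, 4), torch.ones(2, 4, dtype=torch.bool))
buffers = unpack_meta(pack_meta(tensors))
assert buffers[1].dtype == torch.bool
```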

Member:

But in this particular case it's just a tensor, so it should be easy to allow it if we require pt>=1.7.

We can support much more flexible message shapes/types with more robust meta-communication. There are some p2p routines that serialize the message with pickle() and then just send that over to exchange the meta-information. Using pickle has a few tradeoffs, including an expensive GPU -> CPU -> GPU transfer for the first activation exchange, and the typical security risks associated with pickling.

That would be great! But we can think of advanced features separately - as I mentioned in the past when porting PP, this would be extremely helpful to anybody wanting to integrate DS's PP into their software.

So allowing bool tensors is the only showstopper at the moment, if I understand it correctly. Thomas, please correct me if I'm wrong.

thomasw21 (Member, Author) commented Sep 23, 2021:

Well, I'm currently hitting a deadlock somewhere (or just something very slow). But at least there are no exceptions.

Concerning the upgrade to torch >= 1.7: I think we just need to update the meta communication. I can try to implement a version on my branch to see if I can remove the hack. Though I'm hitting another issue (see the first paragraph) right now, and I'm unclear exactly why ...

I think we can just use an id to denote which dtype we pass, and keep a dictionary mapping each id to a dtype. Using an id would allow us to keep passing only tensors, similarly to how an id is already used to determine whether we communicate a tensor or a tuple of tensors. See the sketch below.
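A sketch of that id-based idea (purely illustrative, not DeepSpeed's actual meta protocol): both stages share a fixed dtype table, so only a small integer has to be exchanged along with the shape in order to allocate a correctly typed receive buffer.

```python
import torch

# Table agreed on by all pipeline stages ahead of time.
DTYPE_TO_ID = {torch.float16: 0, torch.float32: 1, torch.int64: 2, torch.bool: 3}
ID_TO_DTYPE = {i: dt for dt, i in DTYPE_TO_ID.items()}

def meta_for_send(tensor):
    # What the sending stage would put into the meta message.
    return (list(tensor.shape), DTYPE_TO_ID[tensor.dtype])

def buffer_from_meta(shape, dtype_id):
    # What the receiving stage would allocate before the actual recv.
    return torch.empty(shape, dtype=ID_TO_DTYPE[dtype_id])

shape, dtype_id = meta_for_send(torch.ones(2, 5, dtype=torch.bool))
assert buffer_from_meta(shape, dtype_id).dtype == torch.bool
```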

thomasw21 (Member, Author):

cc @StellaAthena @ontocord, as you both expressed interest in deepspeed issues for prefix-lm.

thomasw21 (Member, Author) commented Sep 29, 2021

Small update:

Things I tried to fix PP>1:

  • switched the optimizer to Adam instead of FusedAdam (this can be changed in this repo). No success, same symptoms.
  • tried to trigger a KeyboardInterrupt in order to see where it deadlocks. I'm having trouble viewing the result, as the launcher doesn't give access to subprocess logs, i.e. it only prints out "Killing subprocess {process_ID}".
  • tried switching to gloo. I'm getting RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

@@ -174,3 +174,114 @@ def test_training_all(self):
        # test tensorboard (1 file from the first run, plus 1 now)
        tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
        self.assertEqual(len(tensorboard_files), 2, "tensorboard files")

    @unittest.skip("Skip until deepspeed allows to pass tuple of tensors between pipeline stages.")
thomasw21 (Member, Author):

The test passes with https://github.com/thomasw21/DeepSpeed/tree/big-science-fix-passing-multiple-tensors/deepspeed . I'll wait for that PR to be merged into the big-science branch in order to un-skip the test.

Member:

Which PR is that in DS land?

Member:

I did test with your branch and the test works.

Member:

OK, I found it - you mentioned it elsewhere: microsoft/DeepSpeed#1400

All is good.

thomasw21 (Member, Author):

Whoops, sorry, wrong link. Are you okay with merging as is? (We can un-skip the tests as soon as the DS PR is merged, and update the requirements.) It should not break anything (test_gpt times out, which is unrelated to this PR; I guess I'll create an issue for that).

Member:

Any reason not to wait till the PR on DS side is merged?

Member:

If you feel that it'll unblock others, then it's okay to merge.

But we just don't know yet whether DS is going to accept your proposed changes...

@thomasw21 thomasw21 changed the title WIP: Fix deepspeed prefix-lm Fix deepspeed prefix-lm Oct 5, 2021
@thomasw21 thomasw21 marked this pull request as ready for review October 5, 2021 16:29
thomasw21 (Member, Author) commented Oct 5, 2021

So this PR is a partial fix for prefix-lm; the other part is in deepspeed: microsoft/DeepSpeed#1400

I don't think this PR breaks anything, so I think it's ready for review! cc @stas00

stas00 (Member) left a comment:

Looks good - thank you for the additional commentary, Thomas

@thomasw21 thomasw21 changed the base branch from main to master October 10, 2021 17:20
@thomasw21 thomasw21 changed the base branch from master to main October 10, 2021 17:20
@thomasw21 thomasw21 merged commit da31db6 into main Oct 10, 2021
@thomasw21 thomasw21 deleted the thomas/fix_deepspeed_prefix branch July 4, 2022 08:49