Skip to content

Commit

Permalink
Merge pull request #49 from xiangking/develop
Browse files Browse the repository at this point in the history
bug修复
  • Loading branch information
xiangking committed Apr 9, 2022
2 parents f8aff87 + 06b09fd commit 90debc3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions ark_nlp/nn/layer/global_pointer_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def forward(self, inputs, mask=None):
# RoPE编码
if self.RoPE:
pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
cos_pos = pos[..., None, 1::2].repeat(1, 1, 1, 2)
sin_pos = pos[..., None, ::2].repeat(1, 1, 1, 2)
cos_pos = pos[..., None, 1::2].repeat_interleave(2, dim=-1)
sin_pos = pos[..., None, ::2].repeat_interleave(2, dim=-1)
qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 4)
qw2 = torch.reshape(qw2, qw.shape)
qw = qw * cos_pos + qw2 * sin_pos
Expand Down Expand Up @@ -112,8 +112,8 @@ def forward(self, inputs, mask=None):
# RoPE编码
if self.RoPE:
pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
cos_pos = pos[..., 1::2].repeat(1, 1, 2)
sin_pos = pos[..., ::2].repeat(1, 1, 2)
cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)
sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)
qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
qw2 = torch.reshape(qw2, qw.shape)
qw = qw * cos_pos + qw2 * sin_pos
Expand Down

0 comments on commit 90debc3

Please sign in to comment.