xiaohei66 RaushanTurganbay HF Staff committed on
Commit
a1e4dc4
·
1 Parent(s): 2b77538

Update modeling_paddleocr_vl.py (#91)

Browse files

- Update modeling_paddleocr_vl.py (dd40b2ca2dd01bb8076495edd7b7468e39816257)


Co-authored-by: Raushan Turganbay <RaushanTurganbay@users.noreply.huggingface.co>

Files changed (1) hide show
  1. modeling_paddleocr_vl.py +8 -172
modeling_paddleocr_vl.py CHANGED
@@ -27,11 +27,10 @@ from transformers.activations import ACT2FN, GELUActivation
27
  from transformers.cache_utils import (
28
  Cache,
29
  DynamicCache,
30
- SlidingWindowCache,
31
- StaticCache,
32
  )
33
  from transformers.generation import GenerationMixin
34
  from transformers.integrations import use_kernel_forward_from_hub
 
35
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
36
  from transformers.modeling_layers import GradientCheckpointingLayer
37
  from transformers.modeling_outputs import (
@@ -604,12 +603,13 @@ class Ernie4_5Model(Ernie4_5PreTrainedModel):
604
  elif position_ids.dim() == 2:
605
  position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
606
 
607
- causal_mask = self._update_causal_mask(
608
- attention_mask,
609
- inputs_embeds,
610
- cache_position,
611
- past_key_values,
612
- output_attentions,
 
613
  )
614
 
615
  hidden_states = inputs_embeds
@@ -632,170 +632,6 @@ class Ernie4_5Model(Ernie4_5PreTrainedModel):
632
  past_key_values=past_key_values,
633
  )
634
 
635
- def _update_causal_mask(
636
- self,
637
- attention_mask: torch.Tensor,
638
- input_tensor: torch.Tensor,
639
- cache_position: torch.Tensor,
640
- past_key_values: Cache,
641
- output_attentions: bool = False,
642
- ):
643
- if self.config._attn_implementation == "flash_attention_2":
644
- if attention_mask is not None and past_key_values is not None:
645
- is_padding_right = (
646
- attention_mask[:, -1].sum().item() != input_tensor.size()[0]
647
- )
648
- if is_padding_right:
649
- raise ValueError
650
- if attention_mask is not None and 0.0 in attention_mask:
651
- return attention_mask
652
- return None
653
-
654
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
655
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
656
- # to infer the attention mask.
657
- past_seen_tokens = (
658
- past_key_values.get_seq_length() if past_key_values is not None else 0
659
- )
660
- using_static_cache = isinstance(past_key_values, StaticCache)
661
- using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
662
-
663
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
664
- if (
665
- self.config._attn_implementation == "sdpa"
666
- and not (using_static_cache or using_sliding_window_cache)
667
- and not output_attentions
668
- ):
669
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
670
- attention_mask,
671
- inputs_embeds=input_tensor,
672
- past_key_values_length=past_seen_tokens,
673
- sliding_window=self.config.sliding_window,
674
- is_training=self.training,
675
- ):
676
- return None
677
-
678
- dtype, device = input_tensor.dtype, input_tensor.device
679
- min_dtype = torch.finfo(dtype).min
680
- sequence_length = input_tensor.shape[1]
681
- # SlidingWindowCache or StaticCache
682
- if using_sliding_window_cache or using_static_cache:
683
- target_length = past_key_values.get_max_cache_shape()
684
- # DynamicCache or no cache
685
- else:
686
- target_length = (
687
- attention_mask.shape[-1]
688
- if isinstance(attention_mask, torch.Tensor)
689
- else past_seen_tokens + sequence_length + 1
690
- )
691
-
692
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
693
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
694
- attention_mask,
695
- sequence_length=sequence_length,
696
- target_length=target_length,
697
- dtype=dtype,
698
- device=device,
699
- cache_position=cache_position,
700
- batch_size=input_tensor.shape[0],
701
- config=self.config,
702
- past_key_values=past_key_values,
703
- )
704
-
705
- if (
706
- self.config._attn_implementation == "sdpa"
707
- and attention_mask is not None
708
- and attention_mask.device.type in ["cuda", "xpu"]
709
- and not output_attentions
710
- ):
711
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
712
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
713
- # Details: https://github.com/pytorch/pytorch/issues/110213
714
- causal_mask = AttentionMaskConverter._unmask_unattended(
715
- causal_mask, min_dtype
716
- )
717
-
718
- return causal_mask
719
-
720
- @staticmethod
721
- def _prepare_4d_causal_attention_mask_with_cache_position(
722
- attention_mask: torch.Tensor,
723
- sequence_length: int,
724
- target_length: int,
725
- dtype: torch.dtype,
726
- device: torch.device,
727
- cache_position: torch.Tensor,
728
- batch_size: int,
729
- config: PaddleOCRVLConfig,
730
- past_key_values: Cache,
731
- ):
732
- """
733
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
734
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
735
-
736
- Args:
737
- attention_mask (`torch.Tensor`):
738
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
739
- sequence_length (`int`):
740
- The sequence length being processed.
741
- target_length (`int`):
742
- The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
743
- dtype (`torch.dtype`):
744
- The dtype to use for the 4D attention mask.
745
- device (`torch.device`):
746
- The device to place the 4D attention mask on.
747
- cache_position (`torch.Tensor`):
748
- Indices depicting the position of the input sequence tokens in the sequence.
749
- batch_size (`torch.Tensor`):
750
- Batch size.
751
- config (`PaddleOCRVLConfig`):
752
- The model's configuration class
753
- past_key_values (`Cache`):
754
- The cache class that is being used currently to generate
755
- """
756
- if attention_mask is not None and attention_mask.dim() == 4:
757
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
758
- causal_mask = attention_mask
759
- else:
760
- min_dtype = torch.finfo(dtype).min
761
- causal_mask = torch.full(
762
- (sequence_length, target_length),
763
- fill_value=min_dtype,
764
- dtype=dtype,
765
- device=device,
766
- )
767
- diagonal_attend_mask = torch.arange(
768
- target_length, device=device
769
- ) > cache_position.reshape(-1, 1)
770
- if config.sliding_window is not None:
771
- # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
772
- # the check is needed to verify is current checkpoint was trained with sliding window or not
773
- if (
774
- not isinstance(past_key_values, SlidingWindowCache)
775
- or sequence_length > target_length
776
- ):
777
- sliding_attend_mask = torch.arange(
778
- target_length, device=device
779
- ) <= (cache_position.reshape(-1, 1) - config.sliding_window)
780
- diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
781
- causal_mask *= diagonal_attend_mask
782
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
783
- if attention_mask is not None:
784
- causal_mask = (
785
- causal_mask.clone()
786
- ) # copy to contiguous memory for in-place edit
787
- if attention_mask.shape[-1] > target_length:
788
- attention_mask = attention_mask[:, :target_length]
789
- mask_length = attention_mask.shape[-1]
790
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
791
- :, None, None, :
792
- ].to(causal_mask.device)
793
- padding_mask = padding_mask == 0
794
- causal_mask[:, :, :, :mask_length] = causal_mask[
795
- :, :, :, :mask_length
796
- ].masked_fill(padding_mask, min_dtype)
797
- return causal_mask
798
-
799
 
800
  class Ernie4_5ForCausalLM(Ernie4_5PreTrainedModel, GenerationMixin):
801
  _tied_weights_keys = ["lm_head.weight"]
 
27
  from transformers.cache_utils import (
28
  Cache,
29
  DynamicCache,
 
 
30
  )
31
  from transformers.generation import GenerationMixin
32
  from transformers.integrations import use_kernel_forward_from_hub
33
+ from transformers.masking_utils import create_causal_mask
34
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
35
  from transformers.modeling_layers import GradientCheckpointingLayer
36
  from transformers.modeling_outputs import (
 
603
  elif position_ids.dim() == 2:
604
  position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
605
 
606
+ causal_mask = create_causal_mask(
607
+ config=self.config,
608
+ inputs_embeds=inputs_embeds,
609
+ attention_mask=attention_mask,
610
+ past_key_values=past_key_values,
611
+ position_ids=position_ids,
612
+ cache_position=cache_position,
613
  )
614
 
615
  hidden_states = inputs_embeds
 
632
  past_key_values=past_key_values,
633
  )
634
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
635
 
636
  class Ernie4_5ForCausalLM(Ernie4_5PreTrainedModel, GenerationMixin):
637
  _tied_weights_keys = ["lm_head.weight"]