|
|
| from typing import List, Optional, Tuple, Union |
| import torch, math |
| import torch.utils.checkpoint |
| from torch import nn |
| import transformers |
| from flash_attn import flash_attn_varlen_func |
| from transformers.activations import ACT2FN |
| from PIL import Image |
| import io, fire |
| from torch.nn import functional as F |
|
|
class OmniVisualEncoder(transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VisionTransformerPretrainedModel):
    """Qwen2-VL vision tower with the inherited patch merger removed.

    ``__init__`` deletes the parent's ``merger`` submodule, so ``forward``
    returns the hidden states of the last vision block without any spatial
    merging or projection; that step is left to the caller.
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): nothing visible in this file reads this attribute. If
        # the intent was to force flash attention, it was probably meant to be
        # ``self.config._attn_implementation`` -- confirm before relying on it.
        self.config_attn_implementation = 'flash_attention_2'
        # Checkpointing is enabled unconditionally; forward() only uses it
        # while the module is in training mode.
        self.gradient_checkpointing = True
        self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
        # Default to a 2x2 merge when the config does not define merge_size.
        self.merge_size = getattr(config, 'merge_size', 2)
        del self.merger

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor,
    ):
        """Run flattened patches through the patch embedding and all vision blocks.

        Args:
            pixel_values: flattened patches, one row per patch (``fake_input``
                builds rows of length ``3 * temporal_patch_size * patch_size**2``).
                Cast to ``self.get_dtype()`` before embedding.
            grid_thw: integer tensor with one ``(t, h, w)`` row per image/video;
                ``t * h * w`` rows of ``pixel_values`` are assumed to belong to
                each entry -- TODO confirm against the caller.

        Returns:
            Hidden states produced by the final block (no merger applied).
        """
        grid_thw = grid_thw.to(pixel_values.device)
        hidden_states = self.patch_embed(pixel_values.to(self.get_dtype()))
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # Variable-length attention boundaries: every temporal frame of every
        # grid entry is its own sequence of h*w tokens. The cumulative offsets
        # are int32 and get a leading 0.
        tokens_per_frame = grid_thw[:, 1] * grid_thw[:, 2]
        frame_lengths = torch.repeat_interleave(tokens_per_frame, grid_thw[:, 0])
        offsets = frame_lengths.cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(offsets, (1, 0), value=0)

        use_checkpointing = self.gradient_checkpointing and self.training
        for block in self.blocks:
            if use_checkpointing:
                # Activations are recomputed during backward; arguments are
                # passed positionally to block.__call__.
                hidden_states = self._gradient_checkpointing_func(
                    block.__call__, hidden_states, cu_seqlens, rotary_pos_emb
                )
                continue
            hidden_states = block(
                hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
            )

        return hidden_states

    @torch.no_grad()
    def fake_input(self, device):
        """Build an all-zero placeholder input covering one ``merge x merge`` grid.

        ``merge`` is ``max(self.merge_size, config.spatial_merge_size)``; the
        final reshape only succeeds when ``merge`` is a multiple of
        ``config.spatial_merge_size``.

        Returns:
            A 3-tuple ``([patches], [(1, merge, merge)], [1])`` where ``patches``
            is a float32 zero tensor on ``device`` of shape
            ``(merge * merge, 3 * temporal_patch_size * patch_size**2)``.
        """
        cfg = self.config
        spatial = cfg.spatial_merge_size
        merge = max(self.merge_size, spatial)
        groups_per_side = merge // spatial
        temporal = cfg.temporal_patch_size
        patch = cfg.patch_size

        # Axes: (1, t, channel, h-groups, h-in-group, patch_h,
        #        w-groups, w-in-group, patch_w)
        layout = (1, temporal, 3, groups_per_side, spatial, patch, groups_per_side, spatial, patch)
        image = torch.zeros(layout, dtype=torch.float32, device=device)

        # Move the spatial group indices to the front so that each merge group's
        # patches end up adjacent after flattening:
        # (1, h-groups, w-groups, h-in-group, w-in-group, channel, t, patch_h, patch_w)
        grouped = image.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)

        row_width = 3 * temporal * patch * patch
        patches = grouped.reshape(merge * merge, row_width)
        return [patches], [(1, merge, merge)], [1]
|
|
|
|
class OmniVisualBridge(nn.Module):
    """Group neighbouring vision tokens and project them with a two-layer MLP.

    Input rows of width ``config.embed_dim`` are layer-normalised, then every
    ``merge_size**2`` consecutive rows (in flattened order) are concatenated
    into one vector of width ``embed_dim * merge_size**2``. That vector is
    passed through Linear -> GELU -> Linear, ending at ``config.hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Default to a 2x2 merge when the config does not define merge_size.
        self.merge_size = getattr(config, 'merge_size', 2)
        embed_dim = config.embed_dim
        # Width of one concatenated group of merge_size**2 tokens.
        self.hidden_size = embed_dim * self.merge_size ** 2
        self.ln_q = nn.LayerNorm(embed_dim, eps=1e-6)
        # Sequential indices 0/2 hold the Linear layers; they determine the
        # state_dict keys ("mlp.0.*", "mlp.2.*").
        layers = [
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, config.hidden_size),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to ``(-1, config.hidden_size)``.

        The total number of elements in ``x`` must be divisible by
        ``self.hidden_size``; ``.view`` raises otherwise.
        """
        normed = self.ln_q(x)
        grouped = normed.view(-1, self.hidden_size)
        return self.mlp(grouped)
|
|
|
|
# Script entry point: hand the command line to python-fire. No component is
# passed, so fire builds the CLI from the calling frame's namespace (the
# module globals here). Keep the call at module level; wrapping it in a
# function would change which names fire exposes.
if __name__ == '__main__':
    fire.Fire()
| |