| |
| """ |
| # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """ |
|
|
| from abc import ABC |
| import torch |
| import torch.nn.functional as F |
| from typing import Dict, Optional |
|
|
| import torch.nn as nn |
| from einops import pack, rearrange, repeat |
| from .matcha_components import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D |
| from .matcha_transformer import BasicTransformerBlock |
| from omegaconf import DictConfig |
|
|
|
|
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Turn a boolean attention mask into an additive attention bias.

    Args:
        mask (torch.Tensor): boolean mask. ``True`` entries become a bias of
            zero, ``False`` entries become the most negative finite value
            representable in ``dtype``.
        dtype (torch.dtype): floating dtype of the result; must be one of
            ``torch.float32``, ``torch.bfloat16`` or ``torch.float16``.

    Returns:
        torch.Tensor: tensor of ``dtype`` with the same shape as ``mask``.
    """
    assert mask.dtype == torch.bool
    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
    most_negative = torch.finfo(dtype).min
    keep = mask.to(dtype)
    # keep == 1 -> 0 * min (zero bias); keep == 0 -> 1 * min (blocked).
    return (1.0 - keep) * most_negative
|
|
|
|
def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a chunk-wise attention mask of shape (size, size).

    Row ``i`` describes which positions query ``i`` may attend to: every
    position up to the end of the chunk that contains ``i`` and, when
    ``num_left_chunks`` is non-negative, no further back than
    ``num_left_chunks`` chunks before that chunk. This is intended for a
    streaming encoder.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use all left chunks
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: boolean mask of shape (size, size)

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
        >>> subsequent_chunk_mask(4, 2, num_left_chunks=0)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [0, 0, 1, 1],
         [0, 0, 1, 1]]
    """
    pos_idx = torch.arange(size, device=device)
    # Index of the chunk each query position falls into.
    chunk_idx = torch.div(pos_idx, chunk_size, rounding_mode='trunc')
    # Exclusive right edge of that chunk: keys must lie strictly before it.
    chunk_end = (chunk_idx + 1) * chunk_size
    ret = pos_idx.unsqueeze(0) < chunk_end.unsqueeze(1)
    if num_left_chunks >= 0:
        # Inclusive left edge: start of the chunk `num_left_chunks` chunks
        # back, clamped so it never goes before position 0. Kept as pure
        # tensor ops (no Python loop over positions).
        chunk_start = torch.clamp(chunk_idx - num_left_chunks, min=0) * chunk_size
        ret = ret & (pos_idx.unsqueeze(0) >= chunk_start.unsqueeze(1))
    return ret
|
|
def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a lower-triangular (causal) mask of shape (size, size).

    Entry ``[i, j]`` is ``True`` exactly when ``j <= i``, i.e. each step may
    attend to itself and to the steps on its left. This is the mask used by
    a decoder working in auto-regressive mode.

    When an encoder uses full attention no such mask is needed; for
    chunk-based streaming attention see subsequent_chunk_mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: boolean mask of shape (size, size)

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    positions = torch.arange(size, device=device)
    key_idx = positions.unsqueeze(0)     # (1, size): column index j
    query_idx = positions.unsqueeze(-1)  # (size, 1): row index i
    return key_idx <= query_idx
|
|
|
|
def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: chunk mask of the input xs. When neither dynamic nor
            static chunking is requested, ``masks`` itself is returned.
    """
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # Training: draw a random chunk size.
            num_left_chunks = -1
            if max_len <= 1:
                # torch.randint(1, max_len) needs a non-empty range; with at
                # most one frame a single chunk already covers everything.
                chunk_size = max(max_len, 1)
            else:
                chunk_size = torch.randint(1, max_len, (1, )).item()
                if chunk_size > max_len // 2 and enable_full_context:
                    chunk_size = max_len
                else:
                    chunk_size = chunk_size % 25 + 1
                    if use_dynamic_left_chunk:
                        max_left_chunks = (max_len - 1) // chunk_size
                        if max_left_chunks > 0:
                            num_left_chunks = torch.randint(0, max_left_chunks,
                                                            (1, )).item()
                        else:
                            # The chunk already spans the whole sequence, so
                            # there is nothing to the left to sample from
                            # (torch.randint(0, 0) would raise).
                            num_left_chunks = 0
    elif static_chunk_size > 0:
        chunk_size = static_chunk_size
        num_left_chunks = num_decoding_left_chunks
    else:
        # No chunking requested: the padding mask is used unchanged.
        return masks

    chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                        num_left_chunks,
                                        xs.device)  # (L, L)
    chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
    # Broadcast against the padding mask.
    return masks & chunk_masks
|
|
|
|
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Build a boolean mask that is True on padded positions.

    NOTE: an identical ``make_pad_mask`` is defined again further down in
    this module; that later definition is the one bound at import time.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): width of the mask. Values <= 0 mean "use the largest
            entry of ``lengths``".

    Returns:
        torch.Tensor: Mask of shape (B, max_len); entry ``[b, i]`` is True
            when ``i >= lengths[b]``.

    Examples:
        >>> lengths = torch.tensor([5, 3, 2])
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    if max_len <= 0:
        max_len = lengths.max().item()
    num_seqs = lengths.size(0)
    positions = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    position_grid = positions.unsqueeze(0).expand(num_seqs, max_len)
    return position_grid >= lengths.unsqueeze(-1)
|
|
| |
class Transpose(torch.nn.Module):
    """Module wrapper around ``torch.transpose`` so it can sit inside a Sequential."""

    def __init__(self, dim0: int, dim1: int):
        """Remember the two dimensions that ``forward`` will swap."""
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x: torch.Tensor):
        """Return ``x`` with ``dim0`` and ``dim1`` swapped."""
        return x.transpose(self.dim0, self.dim1)
|
|
class CausalBlock1D(Block1D):
    """Block1D whose ``block`` is causal conv -> LayerNorm -> Mish."""

    def __init__(self, dim: int, dim_out: int):
        """Build the block.

        Args:
            dim (int): number of input channels.
            dim_out (int): number of output channels.
        """
        super().__init__(dim, dim_out)
        # Assigned after the parent constructor so this stack is the one
        # stored under ``self.block``.
        layers = [
            CausalConv1d(dim, dim_out, 3),
            # LayerNorm(dim_out) normalises the last dimension, so channels
            # are moved last for the norm and moved back afterwards.
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        ]
        self.block = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """Apply the block to ``x * mask`` and mask the result again."""
        masked_input = x * mask
        return self.block(masked_input) * mask
|
|
class CausalResnetBlock1D(ResnetBlock1D):
    """ResnetBlock1D whose ``block1`` / ``block2`` are CausalBlock1D instances."""

    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        """Build the parent block, then install the two causal sub-blocks.

        Args:
            dim (int): number of input channels.
            dim_out (int): number of output channels.
            time_emb_dim (int): forwarded to the parent constructor.
            groups (int): forwarded to the parent constructor.
        """
        super().__init__(dim, dim_out, time_emb_dim, groups)
        # Assigned after the parent constructor, in the same order as before.
        first_block = CausalBlock1D(dim, dim_out)
        self.block1 = first_block
        second_block = CausalBlock1D(dim_out, dim_out)
        self.block2 = second_block
|
|
class CausalConv1d(torch.nn.Conv1d):
    """Conv1d that pads only on the left, so output ``t`` never sees inputs after ``t``.

    The output has the same length as the input. Padding is always done with
    zeros via ``F.pad``; the underlying Conv1d is built with ``padding=0``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        """Build the convolution.

        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
            kernel_size (int): kernel width.
            stride (int): must be 1 (asserted).
            dilation (int): spacing between kernel taps.
            groups (int): forwarded to ``torch.nn.Conv1d``.
            bias (bool): forwarded to ``torch.nn.Conv1d``.
            padding_mode (str): forwarded to ``torch.nn.Conv1d``.
            device: forwarded to ``torch.nn.Conv1d``.
            dtype: forwarded to ``torch.nn.Conv1d``.
        """
        super().__init__(in_channels, out_channels,
                         kernel_size, stride,
                         padding=0, dilation=dilation,
                         groups=groups, bias=bias,
                         padding_mode=padding_mode,
                         device=device, dtype=dtype)
        assert stride == 1
        # The kernel reaches back dilation * (kernel_size - 1) frames, so that
        # many zeros are needed on the left to keep the output length equal to
        # the input length. (Padding only kernel_size - 1 is correct solely
        # for dilation == 1.)
        self.causal_padding = (dilation * (kernel_size - 1), 0)

    def forward(self, x: torch.Tensor):
        """Left-pad the last dimension of ``x`` and apply the convolution."""
        x = F.pad(x, self.causal_padding)
        return super().forward(x)
|
|
|
|
class BASECFM(torch.nn.Module, ABC):
    """Base class for conditional flow matching with a fixed-step Euler sampler.

    ``self.estimator`` is initialised to ``None`` here; it must be set to a
    callable ``estimator(x, mask, mu, t, spks, cond)`` before ``forward``,
    ``solve_euler`` or ``compute_loss`` is used.
    """

    def __init__(
        self,
        n_feats,
        cfm_params,
        n_spks=1,
        spk_emb_dim=128,
    ):
        """Store the configuration.

        Args:
            n_feats: stored as ``self.n_feats``.
            cfm_params: object exposing ``solver`` and, optionally,
                ``sigma_min`` (defaults to 1e-4 when absent).
            n_spks: stored as ``self.n_spks``.
            spk_emb_dim: stored as ``self.spk_emb_dim``.
        """
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        self.sigma_min = getattr(cfm_params, "sigma_min", 1e-4)

        self.estimator = None

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: forwarded unchanged to the estimator

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.

        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: forwarded unchanged to the estimator

        Returns:
            torch.Tensor: the state after the last Euler step.
        """
        t, dt = t_span[0], t_span[1] - t_span[0]
        last_step = len(t_span) - 1

        for step in range(1, len(t_span)):
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)

            # Only the running state is kept; intermediate states are not
            # needed for the result and would only hold on to memory.
            x = x + dt * dphi_dt
            t = t + dt
            if step < last_step:
                dt = t_span[step + 1] - t

        return x

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: accepted for interface symmetry; not passed to the estimator here.

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b = mu.shape[0]

        # One random timestep per batch item, broadcastable over (feats, time).
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        # Sample noise p(x_0).
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # view(b) keeps the batch dimension even when b == 1, where squeeze()
        # would collapse t to a 0-dim tensor.
        pred = self.estimator(y, mask, mu, t.view(b), spks)
        # Mask both operands: the denominator counts valid frames only, so
        # padded frames must not contribute to the numerator either.
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (
            torch.sum(mask) * u.shape[1]
        )
        return loss, y
|
|
|
|
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Build a boolean mask that is True on padded positions.

    NOTE: this redefines the identical ``make_pad_mask`` declared earlier in
    this module; being the later definition, it is the one bound at import
    time.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): width of the mask. Values <= 0 mean "use the largest
            entry of ``lengths``".

    Returns:
        torch.Tensor: Mask of shape (B, max_len); entry ``[b, i]`` is True
            when ``i >= lengths[b]``.

    Examples:
        >>> lengths = torch.tensor([5, 3, 2])
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    width = max_len if max_len > 0 else lengths.max().item()
    frame_idx = torch.arange(0, width, dtype=torch.int64, device=lengths.device)
    # One row of frame indices per sequence, compared against that
    # sequence's length.
    frame_idx = frame_idx.unsqueeze(0).expand(lengths.size(0), width)
    per_seq_len = lengths.unsqueeze(-1)
    return frame_idx >= per_seq_len
|
|
|
|
class ConditionalDecoder(nn.Module):
    """1-D U-Net style estimator: resnet + transformer blocks on a down / mid / up path.

    This decoder requires an input with the same shape of the target. So, if your text content
    is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        causal=False,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        gradient_checkpointing=True,
    ):
        """Build the time embedding and the down / mid / up stacks.

        Args:
            in_channels: channel count of the packed input seen by the first
                resnet block (``x``, ``mu`` and, when given, ``spks`` / ``cond``
                concatenated along the channel axis); also the size of the
                sinusoidal time embedding.
            out_channels: channel count produced by ``final_proj``.
            causal: when truthy, use the causal resnet / conv / final block
                variants instead of the regular ones.
            channels: width of each down level; the up path mirrors it.
            dropout: forwarded to every BasicTransformerBlock.
            attention_head_dim: forwarded to every BasicTransformerBlock.
            n_blocks: number of transformer blocks per level.
            num_mid_blocks: number of (resnet, transformer) pairs in the middle.
            num_heads: forwarded to every BasicTransformerBlock.
            act_fn: forwarded to every BasicTransformerBlock as ``activation_fn``.
            gradient_checkpointing: checkpoint transformer blocks while training.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.causal = causal
        # Chunk size handed to add_optional_chunk_mask for every attention
        # mask. NOTE(review): hard-coded (= 100); the meaning of the factors
        # is not documented in this file.
        self.static_chunk_size = 2 * 25 * 2
        self.gradient_checkpointing = gradient_checkpointing

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        def make_resnet(dim, dim_out):
            # Causal or regular resnet block, depending on the configuration.
            resnet_cls = CausalResnetBlock1D if self.causal else ResnetBlock1D
            return resnet_cls(dim=dim, dim_out=dim_out, time_emb_dim=time_embed_dim)

        def make_transformer_blocks(dim):
            # ``n_blocks`` transformer blocks that all operate at width ``dim``.
            return nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=dim,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

        def make_same_length_conv(dim):
            # Kernel-3 conv used instead of a resampling layer on the last level.
            if self.causal:
                return CausalConv1d(dim, dim, 3)
            return nn.Conv1d(dim, dim, 3, padding=1)

        # Construction order below (resnet, transformers, resampler) is kept
        # fixed so random initialisation draws happen in the same sequence.
        output_channel = in_channels
        for i in range(len(channels)):
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = make_resnet(input_channel, output_channel)
            transformer_blocks = make_transformer_blocks(output_channel)
            downsample = Downsample1D(output_channel) if not is_last else make_same_length_conv(output_channel)
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            # Bind the loop-local width; the ``out_channels`` argument must
            # not be shadowed here.
            output_channel = channels[-1]
            resnet = make_resnet(input_channel, output_channel)
            transformer_blocks = make_transformer_blocks(output_channel)
            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            # Doubled because the skip connection is concatenated on the
            # channel axis in forward().
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = make_resnet(input_channel, output_channel)
            transformer_blocks = make_transformer_blocks(output_channel)
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else make_same_length_conv(output_channel)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        """Kaiming-normal init for Conv1d / Linear weights, constant init for biases and GroupNorm."""
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _apply_transformer_blocks(self, x, mask, t, transformer_blocks):
        """Run one level's transformer blocks on a (batch, channels, time) tensor.

        Builds the attention bias from ``mask`` and ``self.static_chunk_size``,
        applies every block (through gradient checkpointing while training, if
        enabled) and returns the result in (batch, channels, time) layout.
        """
        x = rearrange(x, "b c t -> b t c").contiguous()
        attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1)
        attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
        for transformer_block in transformer_blocks:
            if self.gradient_checkpointing and self.training:
                # NOTE(review): arguments are positional here but keyword in
                # the eager branch - confirm the third positional parameter of
                # the block's forward is the intended one.
                x = torch.utils.checkpoint.checkpoint(
                    transformer_block,
                    x,
                    attn_mask,
                    t,
                )
            else:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
        return rearrange(x, "b t c -> b c t").contiguous()

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): concatenated to ``x`` along the channel axis,
                so it must share the batch and time dimensions of ``x``.
            t (torch.Tensor): shape (batch_size)
            spks (torch.Tensor, optional): shape: (batch_size, condition_channels).
                Repeated along time and concatenated on the channel axis. Defaults to None.
            cond (torch.Tensor, optional): concatenated on the channel axis
                when given. Defaults to None.

        Returns:
            torch.Tensor: ``final_proj`` output multiplied by ``mask``.
        """
        t = self.time_embeddings(t)
        t = t.to(x.dtype)
        t = self.time_mlp(t)
        x = pack([x, mu], "b * t")[0]
        mask = mask.to(x.dtype)
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = self._apply_transformer_blocks(x, mask_down, t, transformer_blocks)
            hiddens.append(x)  # kept for the skip connection on the way up
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        # The mask pushed after the final level is dropped; the mid blocks
        # reuse the mask of the last down level.
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = self._apply_transformer_blocks(x, mask_mid, t, transformer_blocks)

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            # Trim x to the skip's length before concatenating on channels.
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = self._apply_transformer_blocks(x, mask_up, t, transformer_blocks)
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
|
|
|
|
class ConditionalCFM(BASECFM):
    """Conditional flow matching with classifier-free guidance.

    Unlike the base class, the estimator is passed to each method explicitly
    instead of being read from ``self.estimator``.
    """

    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64):
        """Store the configuration.

        Args:
            in_channels: forwarded to the base class as ``n_feats``.
            cfm_params: object exposing ``t_scheduler``, ``training_cfg_rate``
                and ``inference_cfg_rate`` in addition to what the base class reads.
            n_spks: forwarded to the base class.
            spk_emb_dim: forwarded to the base class.
        """
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate

    @torch.inference_mode()
    def forward(self, estimator, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            estimator: callable ``estimator(x, mask, mu, t, spks, cond)``
                returning the vector field at ``x``.
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: forwarded to the estimator

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        if self.t_scheduler == 'cosine':
            # Same warping of the time grid as applied to t in compute_loss.
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(estimator, z, t_span=t_span.to(mu.dtype), mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, estimator, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.

        Args:
            estimator: callable ``estimator(x, mask, mu, t, spks, cond)``.
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: forwarded to the estimator on the conditional pass

        Returns:
            torch.Tensor: the state after the last Euler step.
        """
        t, dt = t_span[0], t_span[1] - t_span[0]
        last_step = len(t_span) - 1

        for step in range(1, len(t_span)):
            dphi_dt = estimator(x, mask, mu, t, spks, cond)
            # Classifier-Free Guidance inference.
            if self.inference_cfg_rate > 0:
                # The unconditional pass blanks every conditioning input.
                # compute_loss drops mu, spks and cond together, so cond is
                # zeroed here as well to match what was trained.
                cfg_dphi_dt = estimator(
                    x, mask,
                    torch.zeros_like(mu), t,
                    torch.zeros_like(spks) if spks is not None else None,
                    cond=torch.zeros_like(cond) if cond is not None else None
                )
                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
                           self.inference_cfg_rate * cfg_dphi_dt)
            # Only the running state is kept; intermediate states are not
            # needed for the result.
            x = x + dt * dphi_dt
            t = t + dt
            if step < last_step:
                dt = t_span[step + 1] - t

        return x

    def compute_loss(self, estimator, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            estimator: callable ``estimator(x, mask, mu, t, spks, cond)``.
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor, optional): extra conditioning, dropped together
                with ``mu`` and ``spks`` for classifier-free guidance.

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        org_dtype = x1.dtype

        b = mu.shape[0]
        # One random timestep per batch item.
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)
        # Sample noise p(x_0).
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # During training, drop the conditions of a random subset of items
        # so the estimator also learns the unconditional field.
        if self.training_cfg_rate > 0:
            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
            mu = mu * cfg_mask.view(-1, 1, 1)
            if spks is not None:
                spks = spks * cfg_mask.view(-1, 1)
            if cond is not None:
                cond = cond * cfg_mask.view(-1, 1, 1)

        # view(b) keeps the batch dimension even when b == 1, where squeeze()
        # would collapse t to a 0-dim tensor.
        pred = estimator(y, mask, mu, t.view(b), spks, cond)
        # Accumulate the loss in float32, then cast back to the input dtype.
        pred = pred.float()
        u = u.float()
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        loss = loss.to(org_dtype)
        return loss, y
|
|