diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py
index 87a481a..c1d7db3 100644
--- a/synthesizer/models/tacotron.py
+++ b/synthesizer/models/tacotron.py
@@ -358,27 +358,6 @@ class Tacotron(nn.Module):
     @r.setter
     def r(self, value):
         self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
-
-    def compute_gst(self, inputs, style_input, speaker_embedding=None):
-        """ Compute global style token """
-        device = inputs.device
-        if isinstance(style_input, dict):
-            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
-            if speaker_embedding is not None:
-                query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
-
-            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
-            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
-            for k_token, v_amplifier in style_input.items():
-                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
-                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
-                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
-        elif style_input is None:
-            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
-        else:
-            gst_outputs = self.gst_layer(style_input, speaker_embedding)  # pylint: disable=not-callable
-        inputs = self._concat_speaker_embedding(inputs, gst_outputs)
-        return inputs
 
     @staticmethod
     def _concat_speaker_embedding(outputs, speaker_embeddings):
@@ -486,7 +465,7 @@ class Tacotron(nn.Module):
             speaker_embedding_style = (gst_embed[style_idx] * scale).astype(np.float32)
             speaker_embedding_style = torch.from_numpy(np.tile(speaker_embedding_style, (x.shape[0], 1))).to(device)
         else:
-            speaker_embedding_style = torch.zeros(2, 1, self.speaker_embedding_size).to(device)
+            speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
         style_embed = self.gst(speaker_embedding_style, speaker_embedding)
         encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
         # style_embed = style_embed.expand_as(encoder_seq)
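
Note on the one-line fix in the second hunk: the neutral style input was previously created with a hardcoded batch dimension of 2, which only works when the speaker embedding batch happens to contain exactly two entries; the patch derives the batch dimension from speaker_embedding itself. A minimal standalone sketch of the shape contract, with hypothetical sizes (256-dim embedding, batch of 4) chosen purely for illustration and not taken from the repository:

import torch

# Hypothetical sizes for illustration only.
batch_size = 4
speaker_embedding_size = 256
speaker_embedding = torch.randn(batch_size, speaker_embedding_size)

# Before the patch: batch dimension hardcoded to 2.
old_style = torch.zeros(2, 1, speaker_embedding_size)

# After the patch: batch dimension taken from the speaker embedding,
# so the zero "neutral style" tensor matches any batch size.
new_style = torch.zeros(speaker_embedding.size(0), 1, speaker_embedding_size)

assert new_style.size(0) == speaker_embedding.size(0)  # holds for every batch size
assert old_style.size(0) != speaker_embedding.size(0)  # mismatch whenever batch != 2

speaker_embedding.size()[0] in the patch and speaker_embedding.size(0) above are equivalent; the latter is merely the more common spelling.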