Description
Thanks a lot for your paper and code!
In your implementation, you don't set an attention mask for the text sequence, either in the textTransEncoder layers or in the LinearTemporalCrossAttention layers. Why does leaving the padded text tokens unmasked not have any effect on the results? The related code is below.
def encode_text(self, text, device):
    with torch.no_grad():
        text = clip.tokenize(text, truncate=True).to(device)
        x = self.clip.token_embedding(text).type(self.clip.dtype)  # [batch_size, n_ctx, latent_dim]
        x = x + self.clip.positional_embedding.type(self.clip.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip.transformer(x)
        x = self.clip.ln_final(x).type(self.clip.dtype)

    # T, B, D
    x = self.text_pre_proj(x)
    xf_out = self.textTransEncoder(x)
    xf_out = self.text_ln(xf_out)
    xf_proj = self.text_proj(xf_out[text.argmax(dim=-1), torch.arange(xf_out.shape[1])])
    # B, T, D
    xf_out = xf_out.permute(1, 0, 2)
    return xf_proj, xf_out
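For context, here is roughly where I would have expected a padding mask to enter on the text-encoder side. This is only a sketch I wrote to illustrate the question, not code from this repo: it assumes a plain sequence-first nn.TransformerEncoder standing in for textTransEncoder, and uses the fact that CLIP's tokenizer zero-pads everything after the end-of-text token, so the mask can be read off the token ids.

import torch
import torch.nn as nn

# Toy stand-in for textTransEncoder: a sequence-first nn.TransformerEncoder.
B, n_ctx, d = 2, 77, 256
encoder_layer = nn.TransformerEncoderLayer(d_model=d, nhead=4)
text_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

# Fake token ids: CLIP zero-pads after the EOT token, so 0 == padding.
tokens = torch.zeros(B, n_ctx, dtype=torch.long)
tokens[0, :5] = torch.randint(1, 1000, (5,))   # 5 real tokens
tokens[1, :9] = torch.randint(1, 1000, (9,))   # 9 real tokens

x = torch.randn(n_ctx, B, d)                    # [T, B, D], as in encode_text

# True marks positions the encoder should ignore.
pad_mask = tokens == 0                          # [B, n_ctx]
xf_out = text_encoder(x, src_key_padding_mask=pad_mask)
print(xf_out.shape)                             # torch.Size([77, 2, 256])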
class LinearTemporalCrossAttention(nn.Module):

    def __init__(self, seq_len, latent_dim, text_latent_dim, num_head, dropout, time_embed_dim):
        super().__init__()
        self.num_head = num_head
        self.norm = nn.LayerNorm(latent_dim)
        self.text_norm = nn.LayerNorm(text_latent_dim)
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key = nn.Linear(text_latent_dim, latent_dim)
        self.value = nn.Linear(text_latent_dim, latent_dim)
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def forward(self, x, xf, emb):
        """
        x: B, T, D
        xf: B, N, L
        """
        B, T, D = x.shape
        N = xf.shape[1]
        H = self.num_head
        # B, T, D
        query = self.query(self.norm(x))
        # B, N, D
        key = self.key(self.text_norm(xf))
        query = F.softmax(query.view(B, T, H, -1), dim=-1)
        key = F.softmax(key.view(B, N, H, -1), dim=1)
        # B, N, H, HD
        value = self.value(self.text_norm(xf)).view(B, N, H, -1)
        # B, H, HD, HD
        attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
        y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D)
        y = x + self.proj_out(y, emb)
        return y
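And here is the kind of masking I had in mind for the cross-attention. Because key is soft-maxed over the token dimension N (dim=1), padded tokens could be dropped by setting their logits to -inf before that softmax. Again, this is only an illustrative sketch with made-up shapes and a hypothetical pad_mask, not the repo's implementation.

import torch
import torch.nn.functional as F

# Shapes mirroring forward() above: T motion steps, N text tokens,
# H heads, HD channels per head.
B, T, N, H, HD = 2, 60, 77, 4, 32
query = torch.randn(B, T, H, HD)
key = torch.randn(B, N, H, HD)
value = torch.randn(B, N, H, HD)

# Hypothetical padding mask: True where a text token is padding.
pad_mask = torch.zeros(B, N, dtype=torch.bool)
pad_mask[:, 20:] = True

# Excluding padded tokens: -inf logits become exactly 0 after the
# softmax over the token dimension, so those positions contribute
# nothing to the key-value summary (their value rows are multiplied
# by zero and drop out automatically).
key = key.masked_fill(pad_mask[:, :, None, None], float('-inf'))
query = F.softmax(query, dim=-1)
key = F.softmax(key, dim=1)

attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
y = torch.einsum('bthd,bhdl->bthl', query, attention).reshape(B, T, H * HD)
print(y.shape)                                  # torch.Size([2, 60, 128])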