import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision import transforms
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
class LineEmbedding(nn.Module):
    def __init__(self, line_length=256, in_channels=32, emb_size=768):
        super().__init__()
        # split the flat feature vector into `line_length` tokens of `in_channels`
        # dims each, then project every token to the embedding size
        self.projection = nn.Sequential(
            Rearrange('b (l c) -> b l c', c=in_channels),
            nn.Linear(in_channels, emb_size)
        )
        # learnable [CLS] token and positional embeddings (one extra slot for [CLS])
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.positions = nn.Parameter(torch.randn(line_length + 1, emb_size))

    def forward(self, line_feats):
        num_lf, lf_dim = line_feats.shape               # (N, line_length * in_channels)
        line_embeddings = self.projection(line_feats)   # (N, line_length, emb_size)
        # prepend one [CLS] token per line, then add positional embeddings
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=num_lf)
        line_embeddings = torch.cat([cls_tokens, line_embeddings], dim=1)
        line_embeddings += self.positions
        return line_embeddings                          # (N, line_length + 1, emb_size)
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size=512, num_heads=8, dropout=0.):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x):
        # project, then split the embedding dim into (num_heads, head_dim)
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        # scaled dot-product attention over the token axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        att = self.att_drop(F.softmax(energy / self.emb_size ** 0.5, dim=-1))
        out = rearrange(torch.einsum('bhal, bhlv -> bhav', att, values), "b h n d -> b n (h d)")
        return self.projection(out)
if __name__ == "__main__":
    # 16 lines, each flattened from a (256, 32) feature map
    line_feats = torch.zeros([16, 256*32])
    LE = LineEmbedding(line_length=256, in_channels=32, emb_size=768)
    line_embeddings = LE(line_feats)
    print(line_embeddings.shape)    # torch.Size([16, 257, 768])