Uncategorized

tmp

침닦는수건 2022. 8. 23. 18:31
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision import transforms
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

class LineEmbedding(nn.Module):
    def __init__(self, line_length=256, in_channels=32, emb_size=768):
        super().__init__()
        # Split the flattened line features back into (line_length, in_channels)
        # and project each of the line_length tokens to emb_size.
        self.projection = nn.Sequential(
            Rearrange('b (l c) -> b l c', c=in_channels),  # (N, l*c) -> (N, l, c)
            nn.Linear(in_channels, emb_size)               # (N, l, c) -> (N, l, emb_size)
        )

        # Learnable [CLS] token and positional embeddings (line_length + 1 positions).
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.positions = nn.Parameter(torch.randn(line_length + 1, emb_size))

    def forward(self, line_feats):
        num_lf, lf_dim = line_feats.shape                                   # (N, line_length * in_channels)
        line_embeddings = self.projection(line_feats)                       # (N, line_length, emb_size)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=num_lf)    # one [CLS] token per sample
        line_embeddings = torch.cat([cls_tokens, line_embeddings], dim=1)   # prepend [CLS]
        line_embeddings += self.positions                                   # add positional embeddings
        return line_embeddings                                              # (N, line_length + 1, emb_size)
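
For reference, the Rearrange pattern above just un-flattens the per-line feature vector; in einops, the last axis inside '(l c)' is the fastest-varying one. A quick shape check with the default sizes used in this post:

import torch
from einops import rearrange

x = torch.zeros([16, 256 * 32])             # (N, line_length * in_channels)
y = rearrange(x, 'b (l c) -> b l c', c=32)  # -> (16, 256, 32)
print(y.shape)                              # torch.Size([16, 256, 32])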

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size=512, num_heads=8, dropout=0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x):
        # Split emb_size into num_heads heads of size d = emb_size // num_heads.
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        # Scaled dot-product attention: (b, h, n_q, d) x (b, h, n_k, d) -> (b, h, n_q, n_k)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        scaling = (self.emb_size // self.num_heads) ** 0.5
        att = self.att_drop(F.softmax(energy / scaling, dim=-1))
        # Weighted sum of the values, then merge heads back to emb_size.
        out = torch.einsum('bhqk, bhkd -> bhqd', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.projection(out)



if __name__=="__main__":

    line_feats = torch.zeros([16, 256*32])
    LE = LineEmbedding(line_length=256, in_channels=32, emb_size=768)
    line_embeddings = LE(line_feats)
    print("")
