@PhilipGAQ
Created August 20, 2024 10:55
model structure
import torch
import torch.nn as nn
import math


# Positional Encoding layer (sinusoidal positional encodings)
class PositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2) * -(math.log(10000.0) / embedding_dim))
        pe = torch.zeros(max_len, 1, embedding_dim)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        # Register as a buffer so the encoding follows the module across devices and into state_dict
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, batch_size, embedding_dim)
        x = x + self.pe[:x.size(0), :]
        return x
# TransformerEncoderLayer: self-attention + feed-forward, each followed by a residual connection and LayerNorm (post-norm)
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )
        # Use a separate LayerNorm per sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
        # `is_causal` is accepted (and ignored) for compatibility with nn.TransformerEncoder,
        # which passes it as a keyword argument in recent PyTorch versions.
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout(src2)
        src = self.norm1(src)
        src2 = self.feed_forward(src)
        src = src + self.dropout(src2)
        src = self.norm2(src)
        return src
# ContextIntegrationModel: prepends a learnable [CLS] token to the chunk embeddings,
# encodes the sequence with a Transformer, and returns a document embedding ([CLS] position)
# together with context-aware chunk embeddings.
class ContextIntegrationModel(nn.Module):
    def __init__(self, embedding_dim, num_heads, num_layers, dropout=0.1):
        super(ContextIntegrationModel, self).__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
        self.pos_encoder = PositionalEncoding(embedding_dim)
        self.encoder_layer = TransformerEncoderLayer(embedding_dim, num_heads)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.new_embedding_generator = nn.Linear(embedding_dim, embedding_dim)
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, chunks_embeddings):
        # chunks_embeddings: (batch_size, num_chunks, embedding_dim)
        batch_size = chunks_embeddings.size(0)
        # Prepend the learnable [CLS] token to every sequence of chunk embeddings
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        chunks_embeddings = torch.cat((cls_tokens, chunks_embeddings), dim=1)
        # The encoder expects (seq_len, batch_size, embedding_dim), so permute before adding positions
        chunks_embeddings = self.pos_encoder(chunks_embeddings.permute(1, 0, 2))
        transformed_embeddings = self.transformer_encoder(chunks_embeddings)
        transformed_embeddings = transformed_embeddings.permute(1, 0, 2)
        # Residual connection with the position-encoded input, then LayerNorm and dropout
        transformed_embeddings = self.layer_norm(transformed_embeddings + chunks_embeddings.permute(1, 0, 2))
        transformed_embeddings = self.dropout(transformed_embeddings)
        # [CLS] position -> document embedding; remaining positions -> new chunk embeddings
        doc_embedding = transformed_embeddings[:, 0]
        new_embeddings = self.new_embedding_generator(transformed_embeddings[:, 1:])
        # L2-normalize both outputs
        doc_embedding = torch.nn.functional.normalize(doc_embedding, p=2, dim=-1)
        new_embeddings = torch.nn.functional.normalize(new_embeddings, p=2, dim=-1)
        return doc_embedding, new_embeddings
# Top-level model wrapping the context integration module
class TransformerModel(nn.Module):
    def __init__(self, embedding_dim, num_heads=8, num_layers=1):
        super(TransformerModel, self).__init__()
        self.context_integration = ContextIntegrationModel(embedding_dim, num_heads, num_layers)

    def forward(self, chunks_embedding):
        # Returns (doc_embedding, new_embeddings) from the context integration module
        transformed_embeddings = self.context_integration(chunks_embedding)
        return transformed_embeddings
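

# Minimal usage sketch (not part of the original gist; the shapes and sizes below are illustrative
# assumptions): pass a batch of pre-computed chunk embeddings of shape
# (batch_size, num_chunks, embedding_dim) and get back an L2-normalized document embedding
# plus context-aware chunk embeddings.
if __name__ == "__main__":
    embedding_dim, num_chunks, batch_size = 768, 10, 4  # assumed example sizes
    model = TransformerModel(embedding_dim, num_heads=8, num_layers=1)
    chunks = torch.randn(batch_size, num_chunks, embedding_dim)  # e.g. embeddings from an upstream encoder
    doc_embedding, new_embeddings = model(chunks)
    print(doc_embedding.shape)   # torch.Size([4, 768])
    print(new_embeddings.shape)  # torch.Size([4, 10, 768])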