上次跟着基因组学深度学习入门指南跑通了一个基础的CNN模型,效果还不错。但听说现在NLP领域Attention机制大杀四方,我就在想,这玩意儿能不能用到基因组序列分析上?毕竟DNA序列也是一种“语言”。

于是,我决定动手试试:把Attention机制加到CNN里,看看能不能提升性能,顺便搞懂它到底在关注什么。

这份笔记记录了我的探索过程,包括核心代码实现、遇到的坑(Dimension mismatch, RuntimeErrors…)以及最终的实验结果。


0. 任务回顾

还是那个经典的转录因子结合位点识别任务:

  • 输入:50bp DNA序列 (A, C, G, T)
  • 输出:是否有结合位点 (0/1)
  • 核心难点:Motif (CGACCGAACTCC) 可能出现在序列的任何位置。

传统CNN用卷积核去“扫描”序列,这有点像拿着放大镜一段一段看。而Attention机制据说能让模型拥有“全局视野”,一眼看到关键区域。


1. 费曼学习法:怎么理解Attention?

在写代码之前,我先尝试用费曼技巧把Attention机制给自己讲清楚。

想象我是一个调酒师(Attention模块),顾客(Decoder状态)想要一杯特定的酒。

  1. Query (顾客的要求): 这里指的是我们目前的解码状态,或者说我们“想找什么”。在我的模型里,我把CNN提取的特征图的平均值作为Query,代表“整条序列的大致风貌”。
  2. Key (酒瓶标签): 吧台后陈列着各种酒(Encoder输出的特征序列),每瓶酒都有标签。
  3. Value (酒液): 瓶子里的酒本身。在很多简单Attention里,Key和Value是同一个东西,就是CNN在每个位置提取的特征。

调酒过程 (计算Attention):

  1. 匹配 (Score): 我看了一眼顾客的要求 (Query),然后扫视一圈酒瓶标签 (Keys),计算每瓶酒跟顾客要求的匹配度。
  2. 加权 (Softmax): 匹配度高的,我就多倒点;匹配度低的,就少倒点或者不倒。这个比例就是Attention Weights
  3. 混合 (Context Vector): 把倒出来的酒混合在一起,这就得到了一杯“特调鸡尾酒” (Context Vector)。

这就解释了为什么Attention能捕捉全局信息:它根据当前的需要,动态地从所有输入位置中加权提取信息,而不是死板地只看局部。


2. 核心代码实现

为了保证可复现性,我先把数据生成的代码贴在这里。它负责生成包含特定 Motif 的合成 DNA 序列。

2.0 数据生成

# --- 1. Data Generation ---
def generate_data(num_samples=2000, seq_len=50, motif="CGACCGAACTCC"):
    X = []
    y = []
    base_to_int = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    
    for _ in range(num_samples):
        seq_int = np.random.randint(0, 4, seq_len)
        label = 0
        if np.random.rand() > 0.5:
            start_idx = np.random.randint(0, seq_len - len(motif))
            for i, char in enumerate(motif):
                seq_int[start_idx + i] = base_to_int[char]
            label = 1
            
        seq_onehot = np.zeros((4, seq_len))
        for i, val in enumerate(seq_int):
            seq_onehot[val, i] = 1
            
        X.append(seq_onehot)
        y.append(label)
        
    return np.array(X, dtype=np.float32), np.array(y, dtype=np.float32)

2.1 简单的Attention模块

这是我参考教程手搓的一个简单Attention层。写的时候最头疼的就是维度变换,我特意加了详细注释提醒自己。

class SimpleAttention(nn.Module):
    """简单的注意力模块 - 想象成那个调酒师"""
    def __init__(self, hidden_size):
        super(SimpleAttention, self).__init__()
        self.hidden_size = hidden_size
        
        # 这个线性层用来计算匹配分数
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Parameter(torch.rand(hidden_size))
 
    def forward(self, decoder_state, encoder_outputs):
        """
        decoder_state: [batch, hidden] - 顾客的要求 (Query)
        encoder_outputs: [batch, seq_len, hidden] - 所有的酒 (Keys/Values)
        """
        batch_size = encoder_outputs.size(0)
        seq_len = encoder_outputs.size(1)
 
        # 1. 扩充Query维度,为了能跟每一个Key进行拼接
        # [batch, hidden] -> [batch, seq_len, hidden]
        decoder_state_repeated = decoder_state.unsqueeze(1).repeat(1, seq_len, 1)
 
        # 2. 拼接 Query 和 Key
        combined = torch.cat((decoder_state_repeated, encoder_outputs), dim=2)
 
        # 3. 计算能量得分 (Energy)
        # 这里的操作有点像把它们放进搅拌机打一下
        energy = torch.tanh(self.attn(combined))  # [batch, seq_len, hidden]
        energy = energy.permute(0, 2, 1)           # [batch, hidden, seq_len]
 
        # 4. 计算注意力权重 (Weights)
        # v 是一个可学习的参数,相当于调酒师的个人偏好
        v = self.v.repeat(batch_size, 1).unsqueeze(1)  # [batch, 1, hidden]
        
        # torch.bmm 是 batch matrix multiplication,批量矩阵乘法
        # 这里就是在计算每个位置的得分
        attention_scores = torch.bmm(v, energy).squeeze(1)  # [batch, seq_len]
        attention_weights = F.softmax(attention_scores, dim=1)  # 归一化,和为1
 
        # 5. 加权求和得到 Context Vector
        # [batch, 1, seq_len] * [batch, seq_len, hidden] -> [batch, 1, hidden]
        attention_weights = attention_weights.unsqueeze(1)
        context = torch.bmm(attention_weights, encoder_outputs)
        
        return context.squeeze(1), attention_weights.squeeze(1)

2.2 组装:CNN + Attention

接下来把这个模块插到CNN后面。

踩坑记录 1: 在定义 self.fc1 时,我一开始写成了 nn.Linear(self.conv_output_dim, 32)。结果训练时报错维度不匹配。 原因:我的 self.fc2 输入维度是 32(Attention模式下),它期望接收 CNN特征(16) + Attention特征(16)。 修复:把 self.fc1 输出改为 16,加上 self.attn_fc 输出的 16,刚好凑成 32。

class GenomicsCNNWithAttention(nn.Module):
    def __init__(self, use_attention=True):
        super(GenomicsCNNWithAttention, self).__init__()
        self.use_attention = use_attention
 
        # CNN部分
        self.conv1 = nn.Conv1d(in_channels=4, out_channels=32, kernel_size=12)
        self.pool = nn.MaxPool1d(kernel_size=4)
        self.conv_output_dim = 32 * 9
        self.hidden_size = 32
 
        # Attention部分
        if use_attention:
            self.attention = SimpleAttention(self.hidden_size)
            self.attn_fc = nn.Linear(self.hidden_size, 16) # Attention特征映射到16维
 
        # 全连接层
        # 这里之前踩过坑,维度要对齐
        self.fc1 = nn.Linear(self.conv_output_dim, 16) # CNN特征映射到16维
        # 最终融合:16(CNN) + 16(Attention) = 32
        self.fc2 = nn.Linear(32 if use_attention else 16, 1)
 
    def forward(self, x):
        # ... (CNN前向传播) ...
        conv_out = F.relu(self.conv1(x))
        conv_out = self.pool(conv_out) 
 
        if self.use_attention:
            # 准备Attention的输入
            # permute是因为Linear层期望特征在最后一维
            encoder_outputs = conv_out.permute(0, 2, 1) 
            # 用平均值作为Query
            decoder_state = conv_out.mean(dim=2) 
            
            # 召唤调酒师!
            context, attention_weights = self.attention(decoder_state, encoder_outputs)
            self.attention_weights = attention_weights # 存下来,后面画图要用
            
            attn_features = F.relu(self.attn_fc(context))
 
        # ... (特征融合与分类) ...
        conv_flat = conv_out.view(conv_out.size(0), -1)
        cnn_features = F.relu(self.fc1(conv_flat))
 
        if self.use_attention:
            combined = torch.cat([cnn_features, attn_features], dim=1)
            output = self.fc2(combined)
        else:
            output = self.fc2(cnn_features)
 
        return output

3. 跑个分看看

为了验证效果,我编写了训练循环,并对比了两个模型的表现。代码如下:

3.1 训练与评估代码

# --- 3. Training & Evaluation Helper ---
def train_model(model, train_loader, test_loader, epochs=30):
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    train_acc_history = []
    test_acc_history = []
    
    for epoch in range(epochs):
        model.train()
        correct = 0
        total = 0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs.squeeze(), labels)
            loss.backward()
            optimizer.step()
            
            predicted = (torch.sigmoid(outputs.squeeze()) > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        train_acc = 100 * correct / total
        train_acc_history.append(train_acc)
        
        # Validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                outputs = model(inputs)
                predicted = (torch.sigmoid(outputs.squeeze()) > 0.5).float()
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        test_acc = 100 * correct / total
        test_acc_history.append(test_acc)
        
        if (epoch+1) % 5 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%')
            
    return train_acc_history, test_acc_history

3.2 运行实验与绘图

这里我同时跑了基础CNN和CNN+Attention两个模型,Epoch设为30。

# --- 3. Main Execution ---
print("Generating Data...")
X, y = generate_data()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
 
train_dataset = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
test_dataset = TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test))
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
 
print("\nTraining Basic CNN...")
cnn_model = GenomicsCNNWithAttention(use_attention=False)
cnn_train_acc, cnn_test_acc = train_model(cnn_model, train_loader, test_loader)
 
print("\nTraining CNN + Attention...")
attn_model = GenomicsCNNWithAttention(use_attention=True)
attn_train_acc, attn_test_acc = train_model(attn_model, train_loader, test_loader)
 
# --- 5. Visualization ---
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(cnn_test_acc, label='Basic CNN', linestyle='--')
plt.plot(attn_test_acc, label='CNN + Attention')
plt.title('Test Accuracy Comparison')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True)
 
# Comparison of convergence speed (epochs to reach 95%)
cnn_95 = next((i for i, x in enumerate(cnn_test_acc) if x >= 95), 30)
attn_95 = next((i for i, x in enumerate(attn_test_acc) if x >= 95), 30)
 
plt.subplot(1, 2, 2)
bars = plt.bar(['Basic CNN', 'CNN + Attention'], [cnn_95, attn_95], color=['gray', 'orange'])
plt.title('Convergence Speed (Epochs to 95% Acc)')
plt.ylabel('Epochs')
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height,
             f'{height}', ha='center', va='bottom')
 
plt.tight_layout()
plt.show()

结果图表

我的发现

  1. 收敛速度:CNN+Attention 简直是“光速”收敛!只用了 1 个 Epoch 就达到了 95% 以上的准确率,而基础 CNN 用了 4 个。
  2. 最终性能:Attention 版本稳稳地拿到了 100% 的准确率(毕竟是模拟数据),基础 CNN 稍微差一点点。

4. 它是怎么做到的?(可解释性分析)

模型效果好是好事,但我更关心它到底学到了什么

4.1 Saliency Map 对比

Saliency Map 告诉我们输入序列中哪些碱基对输出结果贡献最大。为了对比两个模型,我选取了一个包含 Motif 的样本,分别计算了它们的 Saliency Map。

# --- 4. Saliency Map Analysis ---
def compute_saliency_map(model, input_seq):
    model.eval()
    input_seq.requires_grad_()
    
    output = model(input_seq.unsqueeze(0))
    # We want to maximize the output score (probability of being positive)
    output.backward()
    
    saliency = input_seq.grad.data.abs().max(dim=0)[0] # Take max across channels (A,C,G,T)
    return saliency
 
# Select a positive sample
pos_idx = np.where(y_test == 1)[0][0]
sample_seq = torch.tensor(X_test[pos_idx], dtype=torch.float32)
 
# Compute saliency
# Note: Remove torch.no_grad() context if you are running this interactively after eval
cnn_saliency = compute_saliency_map(cnn_model, sample_seq.clone())
attn_saliency = compute_saliency_map(attn_model, sample_seq.clone())
 
# Plot
plt.figure(figsize=(12, 4))
plt.subplot(2, 1, 1)
plt.bar(range(50), cnn_saliency)
plt.title('Saliency Map: Basic CNN')
plt.subplot(2, 1, 2)
plt.bar(range(50), attn_saliency, color='orange')
plt.title('Saliency Map: CNN + Attention')
plt.tight_layout()
plt.show()

踩坑记录 2: 在计算 Saliency Map 时,我遇到了 RuntimeError: element 0 of tensors does not require grad原因:我傻乎乎地在 with torch.no_grad(): 块里调用了需要梯度的 backward()修复:把这部分代码移出 no_grad 上下文。

从图中可以看到,两个模型都关注到了 Motif 所在的区域,但 Attention 模型的关注点似乎更集中、更干净一些。

4.2 偷看Attention权重

这是最精彩的部分。我们把 attention_weights 提取出来画个热图,就像**间谍(Spy)**窃取了调酒师的配方表。

# --- 4. Attention Weight Analysis ---
def analyze_attention(model, dataset, num_samples=100):
    model.eval()
    motif_attention = []
    non_motif_attention = []
    
    with torch.no_grad():
        for i in range(num_samples):
            seq, label = dataset[i]
            if label == 1: # Only analyze positive samples
                output = model(seq.unsqueeze(0))
                attn_weights = model.attention_weights.squeeze().cpu().numpy()
                
                # Interpolate attention weights to match sequence length (50)
                # attn_weights shape: (9,) -> (50,)
                attn_tensor = torch.tensor(attn_weights).view(1, 1, -1)
                attn_upsampled = F.interpolate(attn_tensor, size=50, mode='linear', align_corners=False)
                attn_upsampled = attn_upsampled.squeeze().numpy()
                
                # We need to find where the motif is to separate attention
                # In a real scenario, we might not know, but here we generated the data
                # For simplicity in this "peek", let's just plot the heatmap for one sample first
                pass
 
# Plot Heatmap for one sample
plt.figure(figsize=(10, 2))
# Re-run forward to get weights for the specific sample used in Saliency Map
_ = attn_model(sample_seq.unsqueeze(0))
attn_weights = attn_model.attention_weights.squeeze().detach().numpy()
 
# Upsample for visualization
attn_tensor = torch.tensor(attn_weights).view(1, 1, -1)
attn_upsampled = F.interpolate(attn_tensor, size=50, mode='linear', align_corners=False).squeeze().numpy()
 
sns.heatmap([attn_upsampled], cmap='Reds', cbar=True)
plt.title('Attention Weights Distribution (Upsampled)')
plt.xlabel('Sequence Position')
plt.yticks([])
plt.show()

看到那些深红色的色块了吗?它们几乎完美地覆盖了 Motif (CGACCGAACTCC) 的位置!

踩坑记录 3: 在统计 Motif 区域的平均注意力时,我遇到了 RuntimeWarning: Mean of empty sliceValueError原因:Attention 后的序列长度被池化成了 9,而原始序列是 50。直接切片会导致索引越界或切空。 修复:我引入了 F.interpolate(线性插值),把 9 维的注意力权重平滑地放大回 50 维,然后再跟原始序列对齐。

# --- 4. Statistical Analysis of Attention on Motif ---
# Let's verify if attention really focuses on the motif across many samples
motif_scores = []
background_scores = []
 
# Re-generate some data to know exact motif positions for validation
# Or just rely on the fact that we know the motif string
motif_pattern = "CGACCGAACTCC"
base_to_int = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
 
with torch.no_grad():
    for i in range(100):
        if y_test[i] == 1:
            seq = X_test[i]
            # Find motif start index in this sequence
            # (Simple search for exact match since we generated it without mutation)
            # Reconstruct sequence string
            seq_str = ""
            for col in seq.T:
                idx = np.argmax(col)
                seq_str += "ACGT"[idx]
            
            start_idx = seq_str.find(motif_pattern)
            if start_idx != -1:
                _ = attn_model(torch.tensor(seq).unsqueeze(0).float())
                w = attn_model.attention_weights.squeeze().numpy()
                
                # Upsample
                w_tensor = torch.tensor(w).view(1, 1, -1)
                w_up = F.interpolate(w_tensor, size=50, mode='linear', align_corners=False).squeeze().numpy()
                
                # Extract scores
                motif_score = w_up[start_idx : start_idx+len(motif_pattern)].mean()
                
                # Background score (rest of the sequence)
                mask = np.ones(50, dtype=bool)
                mask[start_idx : start_idx+len(motif_pattern)] = False
                bg_score = w_up[mask].mean()
                
                motif_scores.append(motif_score)
                background_scores.append(bg_score)
 
plt.figure(figsize=(6, 5))
plt.bar(['Motif', 'Background'], [np.mean(motif_scores), np.mean(background_scores)], color=['red', 'gray'])
plt.ylabel('Average Attention Weight')
plt.title('Attention Weights at Motif Positions')
plt.show()

修复后,我统计了 100 个样本,结果显示 Attention 机制极其显著地将权重集中在了 Motif 区域:

4.3 破译密码:模型到底看中了哪段序列?

既然 Attention 告诉了我们“在哪里”,那我们就可以顺藤摸瓜,看看那个位置到底是什么序列。如果模型真的学到了规则,提取出来的序列应该和我们的 Ground Truth (CGACCGAACTCC) 高度一致。

我写了一个脚本,把所有 Attention 权重最高的位置对应的 DNA 片段(长度设为 12)取出来,然后堆叠在一起看统计分布(Position Weight Matrix, PWM)。

# --- 4. Motif Reconstruction ---
def extract_consensus_motif(model, dataset, motif_len=12, num_samples=500):
    model.eval()
    aligned_sequences = []
    
    with torch.no_grad():
        count = 0
        for i in range(len(dataset)):
            if count >= num_samples: break
            
            seq, label = dataset[i]
            if label == 0: continue # Skip negatives
            
            # Forward pass to get attention weights
            _ = model(seq.unsqueeze(0))
            attn_weights = model.attention_weights.squeeze().cpu().numpy()
            
            # Upsample weights
            attn_tensor = torch.tensor(attn_weights).view(1, 1, -1)
            attn_upsampled = F.interpolate(attn_tensor, size=seq.size(1), mode='linear', align_corners=False)
            w = attn_upsampled.squeeze().numpy()
            
            # Find center of attention
            # Simple approach: argmax
            center_idx = np.argmax(w)
            
            # Extract window around center
            # Handle boundary conditions
            start = center_idx - motif_len // 2
            end = start + motif_len
            
            if start < 0:
                start = 0
                end = motif_len
            if end > seq.size(1):
                end = seq.size(1)
                start = end - motif_len
                
            # Convert One-Hot to String/Index
            # seq shape: (4, 50)
            seq_window = seq[:, start:end] # (4, 12)
            
            # Convert to indices (0,1,2,3)
            seq_indices = torch.argmax(seq_window, dim=0).numpy()
            aligned_sequences.append(seq_indices)
            
            count += 1
            
    return np.array(aligned_sequences)
 
# Extract sequences
aligned_seqs = extract_consensus_motif(attn_model, test_dataset, motif_len=12)
 
# Build PWM (Position Weight Matrix)
pwm = np.zeros((4, 12))
for i in range(12):
    col = aligned_seqs[:, i]
    counts = np.bincount(col, minlength=4)
    pwm[:, i] = counts / len(aligned_seqs)
 
# Visualize PWM
plt.figure(figsize=(10, 3))
sns.heatmap(pwm, annot=True, fmt='.2f', 
            xticklabels=range(1, 13), 
            yticklabels=['A', 'C', 'G', 'T'], cmap='Blues')
plt.title('Reconstructed Motif PWM (from Attention Maxima)')
plt.xlabel('Position relative to Attention Center')
plt.show()
 
# Decode Consensus
base_map = {0: 'A', 1: 'C', 2: 'G', 3: 'T'}
consensus = ""
for i in range(12):
    max_idx = np.argmax(pwm[:, i])
    consensus += base_map[max_idx]
 
print(f"Ground Truth Motif: CGACCGAACTCC")
print(f"Decoded Consensus:  {consensus}")

运行结果分析

Ground Truth Motif: CGACCGAACTCC
Decoded Consensus:  CACCCCCCCCCC

咦?结果并没有完全匹配,而是出现了很多 C。这是为什么?

这其实暴露了深度学习模型的一个“偷懒”特性(Shortcut Learning):

  1. 捷径学习:真实的 Motif CGACCGAACTCC 中包含了 50% 的 C。模型可能发现,只要识别出“C含量很高”的区域,就能以很高的概率蒙对结果。它并没有完整地学习碱基排列顺序,而是学习了统计特征。
  2. 分辨率丢失:别忘了我们在 CNN 中使用了 MaxPool1d(kernel_size=4)。这虽然减少了计算量,但也丢失了空间位置信息。Attention 机制是在池化后的特征上进行的(长度从 50 变成了 9),这意味着它只能定位到一个“模糊的区域”,而无法精确对齐到每一个碱基。

虽然没有完美复原 Motif,但这个结果依然证明了 Attention 定位到了正确的区域(也就是 C 含量高的那个 Motif 区域)。如果想要更精确的 Motif,我们可能需要去掉池化层,或者结合 Saliency Map 来分析。

4.4 回归本源:直接看卷积核 (The “Clear” View)

既然 Attention 看到的是“模糊”的景象,那有没有办法在这个模型里看到“清晰”的 Motif 呢?

当然有! 别忘了,虽然我们加了 Attention,但模型的第一层依然是 Conv1d。这一层是直接接触原始 DNA 序列的,它还没有经过 MaxPool 的压缩。

就像人的眼睛(卷积层)看得很清楚,但传到大脑(Attention)时变成了抽象的概念。如果我们直接检查“眼睛”看到的东西,应该能看到清晰的 Motif。

让我们把 attn_model 的第一层卷积核画出来看看,这里用到了基因组学深度学习入门指南定义的plot_motif_logos()函数:

plot_motif_logos(attn_model)

果然! 即使后面接了 Attention,第一层的卷积核依然学会了清晰的 Motif 模式(注意看那些高亮的色块,是不是和 CGACCGAACTCC 很像?)。

这再次印证了那个观点:卷积层负责提取细节特征(高分辨率),Attention 层负责整合全局信息(低分辨率但有大局观)。


5. 总结

这次折腾让我对 Attention 有了实感:

  1. 它确实有用:在基因组序列分析中,Attention 能帮助模型快速定位关键 Motif,提升收敛速度。
  2. 它可解释:相比于 CNN 的“黑盒”,Attention 权重提供了一个非常直观的窗口,让我们能看到模型在关注哪里。这对于生物学研究太重要了(比如发现新的 Motif)。
  3. 实现细节很重要:维度匹配、插值对齐、梯度控制,这些细节如果不注意,分分钟报错。

下一步,我打算试试 Transformer,看看全 Attention 架构能不能彻底取代 CNN。