Bootstrap

该 Python 代码实现了一个基于编码器-解码器网络的模型，用于处理和预测与星座点相关的数据。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt

# 自定义数据集
class MyDataSet(Dataset):
    """Map-style dataset that pairs input rows with label rows and a class entry.

    Args:
        input: sample container; must support ``input.size(0)`` (used as the
            dataset length) and 2-D indexing ``input[index, :]``
            (e.g. a 2-D ``torch.Tensor``).
        label: per-sample targets, indexed as ``label[index, :]``.
        classes: per-sample class information, indexed as ``classes[index]``.

    Note:
        Only ``input`` determines ``len(dataset)``. ``label`` and ``classes``
        are not checked and are assumed to hold the same number of samples.
    """

    def __init__(self, input, label, classes):
        super().__init__()
        # Keep references to the raw containers; rows are sliced lazily.
        self.data, self.label, self.classes = input, label, classes
        self.len = input.size(0)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        sample = self.data[index, :]
        target = self.label[index, :]
        category = self.classes[index]
        return sample, target, category

class EncoderDecoderNet(nn.Module):
    """Encoder/decoder MLP plus a learned linear skip path over 2-feature inputs.

    The first layer is ``nn.Linear(2, ...)``, so inputs are expected to have a
    last dimension of 2. Presumably this is ``(batch, 2)`` I/Q pairs of
    distorted constellation points, with the output being the corresponding
    (ideal) points (TODO confirm against callers). The output has the same
    last dimension (2).

    Layer widths: encoder 2 -> 64 -> 128 -> 64, decoder 64 -> 128 -> 64 -> 2.
    The 64-wide "latent" is therefore wider than the input, not a bottleneck.
    Each hidden stage is Linear -> BatchNorm1d -> LeakyReLU(0.2). In training
    mode BatchNorm1d needs more than one sample per batch.
    """

    def __init__(self):
        super().__init__()

        def linear_bn_act(n_in, n_out):
            # One hidden stage: Linear -> BatchNorm1d -> LeakyReLU(0.2).
            return [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.LeakyReLU(0.2)]

        # Layers are created in a fixed order (encoder, decoder, skip). That
        # order fixes both the seeded parameter initialisation and the
        # Sequential indices used as state_dict keys (encoder.0 ... decoder.6).
        self.encoder = nn.Sequential(
            *linear_bn_act(2, 64),
            *linear_bn_act(64, 128),
            *linear_bn_act(128, 64),
        )
        self.decoder = nn.Sequential(
            *linear_bn_act(64, 128),
            *linear_bn_act(128, 64),
            nn.Linear(64, 2),  # final projection back to the 2 output features
        )
        # Learned linear residual path from the raw input straight to the output.
        self.skip_connection = nn.Linear(2, 2)

    def forward(self, x):
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        residual = self.skip_connection(x)
        return reconstruction + residual

# 损失函数:结合了MSE和相位损失
class ComplexMSELossWithPhase(nn.Module):
    def __init__(self):
        super(ComplexMSELossWithPhase, self
;