import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
class MyDataSet(Dataset):
    """Index-aligned dataset that pairs inputs, labels and per-sample classes.

    Sample ``i`` is the triple ``(input[i, :], label[i, :], classes[i])``.

    Args:
        input: Sample features. ``input.size(0)`` is called, so this is
            presumably a ``torch.Tensor`` (a NumPy array would fail there);
            it must support 2-D ``[index, :]`` indexing. The parameter name
            shadows the builtin but is kept for caller compatibility.
        label: Per-sample targets; must support ``[index, :]`` indexing.
        classes: Per-sample class values; must support ``[index]`` indexing.

    Note:
        The dataset length comes from ``input`` alone. ``label`` and
        ``classes`` are not checked, so they are assumed to be at least
        that long.
    """

    def __init__(self, input, label, classes):
        super().__init__()
        self.data = input
        self.label = label
        self.classes = classes
        # Number of samples = size of the first dimension of the inputs.
        self.len = input.size(0)

    def __len__(self):
        """Return the number of samples (first dimension of ``input``)."""
        return self.len

    def __getitem__(self, index):
        """Return ``(features, target, class)`` for sample ``index``."""
        features = self.data[index, :]
        target = self.label[index, :]
        category = self.classes[index]
        return features, target, category
class EncoderDecoderNet(nn.Module):
    """MLP encoder-decoder on 2 features with an additive linear skip path.

    Architecture (every hidden stage is Linear -> BatchNorm1d -> LeakyReLU(0.2)):

    * encoder: 2 -> 64 -> 128 -> 64
    * decoder: 64 -> 128 -> 64, then a plain Linear 64 -> 2 (no activation)
    * skip_connection: Linear 2 -> 2 applied directly to the input

    The output is ``decoder(encoder(x)) + skip_connection(x)``.
    """

    def __init__(self):
        super().__init__()

        def stage(n_in, n_out):
            # One hidden stage: affine map, batch normalisation, leaky ReLU.
            return [
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.LeakyReLU(0.2),
            ]

        # Layers are created in the same order as a flat listing would create
        # them, so parameter initialisation and state_dict keys are unchanged.
        self.encoder = nn.Sequential(*stage(2, 64), *stage(64, 128), *stage(128, 64))
        self.decoder = nn.Sequential(*stage(64, 128), *stage(128, 64), nn.Linear(64, 2))
        self.skip_connection = nn.Linear(2, 2)

    def forward(self, x):
        """Reconstruct ``x`` through the bottleneck and add the linear skip path.

        ``x`` presumably has shape ``(batch, 2)``; confirm against callers.
        In training mode BatchNorm1d needs more than one sample per batch.
        """
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction + self.skip_connection(x)
class ComplexMSELossWithPhase(nn.Module):
def __init__(self):
super(ComplexMSELossWithPhase, self