class Dataset(torch.utils.data.IterableDataset):
    """Iterable dataset for InversionNet.

    Pairs OpenFWI-style seismic data files with their velocity-model labels,
    laid out as ``<root_dir>/data/data<fid>.npy`` and
    ``<root_dir>/model/model<fid>.npy``.
    """

    def __init__(self, root_dir, fid_list, num_samples_per_file=500):
        """Initialize dataset.

        Args:
            root_dir: Root directory containing ``data/`` and ``model/`` subdirs.
            fid_list: List of npy file ids.
            num_samples_per_file: Number of samples per npy file
                (500 for OpenFWI).
        """
        super().__init__()
        # Data and label paths are index-aligned: data_files[i] pairs with
        # label_files[i] for the same fid.
        self.data_files = [
            os.path.join(root_dir, "data", f"data{fid}.npy") for fid in fid_list
        ]
        self.label_files = [
            os.path.join(root_dir, "model", f"model{fid}.npy") for fid in fid_list
        ]
        self.num_samples_per_file = num_samples_per_file
需要注意的是,当设置num_workers为非零常数(使用multi-process data loading)后,系统会直接将整个dataset复制到每个进程中。因此在实现__iter__()时,需要手动做分片操作。
Model
之前跟Dive into deep learning时已经了解了大概,此处不展开。
class InversionNet(nn.Module):
    """InversionNet consisting of a convolution block and a deconvolution block."""

    def __init__(self, dim1=32, dim2=64, dim3=128, dim4=256, dim5=512, **kwargs):
        """
        Args:
            dim1: Number of channels in the 1st layer.
            dim2: Number of channels in the 2nd layer.
            dim3: Number of channels in the 3rd layer.
            dim4: Number of channels in the 4th layer.
            dim5: Number of channels in the 5th layer.
        """
        super().__init__()
        # NOTE(review): the layer-construction code was elided in the source
        # ("// Init code" — which is not even valid Python); presumably
        # dim1..dim5 are consumed here to build the conv/deconv stacks.
def parse():
    """Create and return the command-line argument parser for the trainer.

    Returns:
        argparse.ArgumentParser: parser exposing the training
        hyper-parameters (batch size, split, epochs, lr, ...) and
        device/dataloader flags.
    """
    parser = argparse.ArgumentParser(
        prog="Trainer", description="InversionNet Pytorch Trainer"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--train-size",
        type=float,
        default=0.8,
        # metavar was "LR" (copy-paste from --lr); this is a proportion.
        metavar="P",
        help="proportion for training dataset (default: 0.8)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=1000,
        metavar="N",
        help="number of epochs to train (default: 1000)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-2,
        metavar="LR",
        help="learning rate (default: 1e-2)",
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=1,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--no-save-model",
        action="store_true",
        default=False,
        help="do not save the current Model",
    )
    parser.add_argument(
        "--seed", type=int, default=42, metavar="S", help="random seed (default: 42)"
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=0,
        metavar="N",
        help="number of processes if using multi-process data loading",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--no-mps",
        action="store_true",
        default=False,
        help="disables macOS GPU training",
    )
    return parser
环境设置
设置训练所需的各种环境。
# Reproducibility: fix the global PyTorch RNG seed.
torch.manual_seed(args.seed)

# Device availability, unless explicitly disabled by command-line flags.
use_cuda = not args.no_cuda and torch.cuda.is_available()
use_mps = not args.no_mps and torch.backends.mps.is_available()

# The original checked `use_mps and args.num_workers == 0` inside
# `if use_mps:` (the first conjunct was redundant) via `assert`, which is
# stripped under `python -O`; raise explicitly instead.
if use_mps and args.num_workers != 0:
    raise RuntimeError("No support for multiprocess dataload using mps")

# DataLoader keyword arguments shared by the train/test loaders.
kwargs = {
    "batch_size": args.batch_size,
    "num_workers": args.num_workers,
    "pin_memory": True,
}

# Prefer CUDA over MPS over CPU.
if use_cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")
# Use a private RNG seeded by --seed so the train/test file split is
# reproducible without disturbing the global `random` state.
r = random.Random(args.seed)
# Count the .npy files without materializing a throwaway list.
num_file = sum(
    1 for file in os.listdir("./data/data") if file.endswith(".npy")
)
# File ids are 1-based: data1.npy .. data<num_file>.npy.
file_idx_list = list(range(1, num_file + 1))
r.shuffle(file_idx_list)