44from torch import optim
55from dataset .dataset_retinal import DatasetRetinal
66from augmentation .transforms import TransformImg , TransformImgMask
7- from model .iternet_ .iternet_model import Iternet
7+ from model .iternet .iternet_model import Iternet
88from trainer .trainer import Trainer
99
1010import argparse
1111
12+
1213def main (args ):
1314 # set the transform
1415 transform_img_mask = TransformImgMask (
15- size = (args .size , args .size ),
16- size_crop = (args .crop_size , args .crop_size ),
16+ size = (args .size , args .size ),
17+ size_crop = (args .crop_size , args .crop_size ),
1718 to_tensor = True
1819 )
19-
20+
2021 # set datasets
2122 csv_dir = {
2223 'train' : args .train_csv ,
2324 'val' : args .val_csv
2425 }
2526 datasets = {
26- x : DatasetRetinal (csv_dir [x ],
27- args .image_dir ,
27+ x : DatasetRetinal (csv_dir [x ],
28+ args .image_dir ,
2829 args .mask_dir ,
2930 batch_size = args .batch_size ,
30- transform_img_mask = transform_img_mask ,
31+ transform_img_mask = transform_img_mask ,
3132 transform_img = TransformImg ()) for x in ['train' , 'val' ]
3233 }
33-
34+
3435 # set dataloaders
3536 dataloaders = {
3637 x : DataLoader (datasets [x ], batch_size = args .batch_size , shuffle = True ) for x in ['train' , 'val' ]
3738 }
38-
39+
3940 # initialize the model
4041 model = Iternet (n_channels = 3 , n_classes = 1 , out_channels = 32 , iterations = 3 )
41-
42+
4243 # set loss function and optimizer
4344 criteria = nn .BCEWithLogitsLoss ()
44- optimizer = optim .RMSprop (model .parameters (), lr = args .lr , weight_decay = 1e-8 , momentum = 0.9 )
45- scheduler = optim .lr_scheduler .ReduceLROnPlateau (optimizer , 'min' if model .n_classes > 1 else 'max' , patience = 2 )
46-
45+ optimizer = optim .RMSprop (
46+ model .parameters (), lr = args .lr , weight_decay = 1e-8 , momentum = 0.9 )
47+ scheduler = optim .lr_scheduler .ReduceLROnPlateau (
48+ optimizer , 'min' if model .n_classes > 1 else 'max' , patience = 2 )
49+
4750 # train the model
48- trainer = Trainer (model , criteria , optimizer , scheduler , args .gpus , args .seed )
51+ trainer = Trainer (model , criteria , optimizer ,
52+ scheduler , args .gpus , args .seed )
4953 trainer (dataloaders , args .epochs , args .model_dir )
5054 torch .cuda .empty_cache ()
5155
56+
5257if __name__ == '__main__' :
5358 parser = argparse .ArgumentParser (description = 'Model Training' )
54- parser .add_argument ('--gpus' , default = '4,5,6' , type = str , help = 'CUDA_VISIBLE_DEVICES' )
55- parser .add_argument ('--size' , default = '592' , type = int , help = 'CUDA_VISIBLE_DEVICES' )
56- parser .add_argument ('--crop_size' , default = '128' , type = int , help = 'CUDA_VISIBLE_DEVICES' )
57- parser .add_argument ('--image_dir' , default = 'data/stare/stare-images/' , type = str , help = 'Images folder path' )
58- parser .add_argument ('--mask_dir' , default = 'data/stare/labels-ah/' , type = str , help = 'Masks folder path' )
59- parser .add_argument ('--train_csv' , default = 'data/stare/train.csv' , type = str , help = 'list of training set' )
60- parser .add_argument ('--val_csv' , default = 'data/stare/val.csv' , type = str , help = 'list of validation set' )
61- parser .add_argument ('--lr' , default = '0.001' , type = float , help = 'learning rate' )
62- parser .add_argument ('--epochs' , default = '2' , type = int , help = 'Number of epochs' )
63- parser .add_argument ('--batch_size' , default = '32' , type = int , help = 'Batch Size' )
64- parser .add_argument ('--model_dir' , default = 'exp/' , type = str , help = 'Images folder path' )
65- parser .add_argument ('--seed' , default = '2020123' , type = int , help = 'Random status' )
59+ parser .add_argument ('--gpus' , default = '4,5,6' ,
60+ type = str , help = 'CUDA_VISIBLE_DEVICES' )
61+ parser .add_argument ('--size' , default = '592' , type = int ,
62+ help = 'CUDA_VISIBLE_DEVICES' )
63+ parser .add_argument ('--crop_size' , default = '128' ,
64+ type = int , help = 'CUDA_VISIBLE_DEVICES' )
65+ parser .add_argument ('--image_dir' , default = 'data/stare/stare-images/' ,
66+ type = str , help = 'Images folder path' )
67+ parser .add_argument ('--mask_dir' , default = 'data/stare/labels-ah/' ,
68+ type = str , help = 'Masks folder path' )
69+ parser .add_argument ('--train_csv' , default = 'data/stare/train.csv' ,
70+ type = str , help = 'list of training set' )
71+ parser .add_argument ('--val_csv' , default = 'data/stare/val.csv' ,
72+ type = str , help = 'list of validation set' )
73+ parser .add_argument ('--lr' , default = '0.0001' ,
74+ type = float , help = 'learning rate' )
75+ parser .add_argument ('--epochs' , default = '2' ,
76+ type = int , help = 'Number of epochs' )
77+ parser .add_argument ('--batch_size' , default = '32' ,
78+ type = int , help = 'Batch Size' )
79+ parser .add_argument ('--model_dir' , default = 'exp/' ,
80+ type = str , help = 'Images folder path' )
81+ parser .add_argument ('--seed' , default = '2020123' ,
82+ type = int , help = 'Random status' )
6683 args = parser .parse_args ()
67-
84+
6885 main (args )
69-
0 commit comments