
Commit ed2f04c

Author: mohamedamri
Commit message: Add requirement.txt
Parent: b18f61a

File tree

11 files changed (+111, -362 lines)

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 data
+.ipynb_checkpoints/*
 */.ipynb_checkpoints/*
 mohamed
 exp

.ipynb_checkpoints/main-checkpoint.py

Lines changed: 44 additions & 28 deletions
(This Jupyter checkpoint copy received the same changes as main.py; the diff is identical to the main.py diff shown below.)

main.py

Lines changed: 44 additions & 28 deletions
@@ -4,66 +4,82 @@
 from torch import optim
 from dataset.dataset_retinal import DatasetRetinal
 from augmentation.transforms import TransformImg, TransformImgMask
-from model.iternet_.iternet_model import Iternet
+from model.iternet.iternet_model import Iternet
 from trainer.trainer import Trainer
 
 import argparse
 
+
 def main(args):
     # set the transform
     transform_img_mask = TransformImgMask(
-        size=(args.size, args.size),
-        size_crop=(args.crop_size, args.crop_size),
+        size=(args.size, args.size),
+        size_crop=(args.crop_size, args.crop_size),
         to_tensor=True
     )
-
+
     # set datasets
     csv_dir = {
         'train': args.train_csv,
         'val': args.val_csv
     }
     datasets = {
-        x: DatasetRetinal(csv_dir[x],
-                          args.image_dir,
+        x: DatasetRetinal(csv_dir[x],
+                          args.image_dir,
                           args.mask_dir,
                           batch_size=args.batch_size,
-                          transform_img_mask=transform_img_mask,
+                          transform_img_mask=transform_img_mask,
                           transform_img=TransformImg()) for x in ['train', 'val']
     }
-
+
     # set dataloaders
     dataloaders = {
         x: DataLoader(datasets[x], batch_size=args.batch_size, shuffle=True) for x in ['train', 'val']
     }
-
+
     # initialize the model
     model = Iternet(n_channels=3, n_classes=1, out_channels=32, iterations=3)
-
+
     # set loss function and optimizer
    criteria = nn.BCEWithLogitsLoss()
-    optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9)
-    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if model.n_classes > 1 else 'max', patience=2)
-
+    optimizer = optim.RMSprop(
+        model.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9)
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer, 'min' if model.n_classes > 1 else 'max', patience=2)
+
     # train the model
-    trainer = Trainer(model, criteria, optimizer, scheduler, args.gpus, args.seed)
+    trainer = Trainer(model, criteria, optimizer,
+                      scheduler, args.gpus, args.seed)
     trainer(dataloaders, args.epochs, args.model_dir)
     torch.cuda.empty_cache()
 
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Model Training')
-    parser.add_argument('--gpus', default='4,5,6', type=str, help='CUDA_VISIBLE_DEVICES')
-    parser.add_argument('--size', default='592', type=int, help='CUDA_VISIBLE_DEVICES')
-    parser.add_argument('--crop_size', default='128', type=int, help='CUDA_VISIBLE_DEVICES')
-    parser.add_argument('--image_dir', default='data/stare/stare-images/', type=str, help='Images folder path')
-    parser.add_argument('--mask_dir', default='data/stare/labels-ah/', type=str, help='Masks folder path')
-    parser.add_argument('--train_csv', default='data/stare/train.csv', type=str, help='list of training set')
-    parser.add_argument('--val_csv', default='data/stare/val.csv', type=str, help='list of validation set')
-    parser.add_argument('--lr', default='0.001', type=float, help='learning rate')
-    parser.add_argument('--epochs', default='2', type=int, help='Number of epochs')
-    parser.add_argument('--batch_size', default='32', type=int, help='Batch Size')
-    parser.add_argument('--model_dir', default='exp/', type=str, help='Images folder path')
-    parser.add_argument('--seed', default='2020123', type=int, help='Random status')
+    parser.add_argument('--gpus', default='4,5,6',
+                        type=str, help='CUDA_VISIBLE_DEVICES')
+    parser.add_argument('--size', default='592', type=int,
+                        help='CUDA_VISIBLE_DEVICES')
+    parser.add_argument('--crop_size', default='128',
+                        type=int, help='CUDA_VISIBLE_DEVICES')
+    parser.add_argument('--image_dir', default='data/stare/stare-images/',
+                        type=str, help='Images folder path')
+    parser.add_argument('--mask_dir', default='data/stare/labels-ah/',
+                        type=str, help='Masks folder path')
+    parser.add_argument('--train_csv', default='data/stare/train.csv',
+                        type=str, help='list of training set')
+    parser.add_argument('--val_csv', default='data/stare/val.csv',
+                        type=str, help='list of validation set')
+    parser.add_argument('--lr', default='0.0001',
+                        type=float, help='learning rate')
+    parser.add_argument('--epochs', default='2',
+                        type=int, help='Number of epochs')
+    parser.add_argument('--batch_size', default='32',
+                        type=int, help='Batch Size')
+    parser.add_argument('--model_dir', default='exp/',
+                        type=str, help='Images folder path')
+    parser.add_argument('--seed', default='2020123',
+                        type=int, help='Random status')
     args = parser.parse_args()
-
+
     main(args)
-
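
For context, a minimal sketch of how this entry point could be exercised after the change, assuming main.py is importable from the repository root. The values simply mirror the argparse defaults in the diff above; the single GPU id '0' is illustrative and not taken from the commit.

    # Hypothetical smoke test: build an argparse.Namespace with the same fields
    # main() reads, and call it directly instead of going through the CLI.
    import argparse
    from main import main

    args = argparse.Namespace(
        gpus='0',                 # illustrative; the default in the diff is '4,5,6'
        size=592, crop_size=128,
        image_dir='data/stare/stare-images/',
        mask_dir='data/stare/labels-ah/',
        train_csv='data/stare/train.csv',
        val_csv='data/stare/val.csv',
        lr=1e-4, epochs=2, batch_size=32,
        model_dir='exp/', seed=2020123,
    )
    main(args)

Equivalently, running "python main.py" with no flags would pick up the same defaults shown in the diff.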

model/iternet/iternet_model.py

Lines changed: 14 additions & 10 deletions
@@ -38,14 +38,15 @@ def forward(self, x):
         x = self.up4(x, x1)
         logits = self.outc(x)
         return x1, x, logits
-
+
+
 class MiniUNet(nn.Module):
     def __init__(self, n_channels, n_classes, out_channels=32):
         super(MiniUNet, self).__init__()
         self.n_channels = n_channels
         self.n_classes = n_classes
         bilinear = False
-
+
         self.inc = DoubleConv(n_channels, out_channels)
         self.down1 = Down(out_channels, out_channels*2)
         self.down2 = Down(out_channels*2, out_channels*4)
@@ -65,24 +66,27 @@ def forward(self, x):
         x = self.up3(x, x1)
         logits = self.outc(x)
         return x1, x, logits
-
+
+
 class Iternet(nn.Module):
     def __init__(self, n_channels, n_classes, out_channels=32, iterations=3):
         super(Iternet, self).__init__()
         self.n_channels = n_channels
         self.n_classes = n_classes
        self.iterations = iterations
-
+
         # define the network UNet layer
-        self.model_unet = UNet(n_channels=n_channels, n_classes=n_classes, out_channels=out_channels)
-
+        self.model_unet = UNet(n_channels=n_channels,
+                               n_classes=n_classes, out_channels=out_channels)
+
         # define the network MiniUNet layers
-        self.model_miniunet = ModuleList(MiniUNet(n_channels=out_channels*2, n_classes=n_classes, out_channels=out_channels) for i in range(iterations))
-
+        self.model_miniunet = ModuleList(MiniUNet(
+            n_channels=out_channels*2, n_classes=n_classes, out_channels=out_channels) for i in range(iterations))
+
     def forward(self, x):
         x1, x2, logits = self.model_unet(x)
         for i in range(self.iterations):
             x = torch.cat([x1, x2], dim=1)
             _, x2, logits = self.model_miniunet[i](x)
-
-        return logits
+
+        return logits
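
As a quick sanity check of the refactored import path, a hedged sketch assuming the module location introduced by this commit (model.iternet.iternet_model) and the constructor signature shown above. The input size and the expectation that spatial resolution is preserved are assumptions, not something this diff states.

    # Hypothetical shape check: UNet produces (x1, x2, logits); each MiniUNet
    # iteration then refines torch.cat([x1, x2], dim=1), and forward() returns
    # the logits of the final iteration.
    import torch
    from model.iternet.iternet_model import Iternet

    model = Iternet(n_channels=3, n_classes=1, out_channels=32, iterations=3)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 128, 128))  # dummy 3-channel patch
    print(logits.shape)  # expected torch.Size([1, 1, 128, 128]) if the convolutions use same padding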

model/unet/.ipynb_checkpoints/unet_model-checkpoint.py

Lines changed: 0 additions & 70 deletions
This file was deleted.

0 commit comments
