Commit 454e1b3: change code to newest pytorch
Parent: fb4d466
5 files changed: 30 additions, 36 deletions
main.py: 19 additions, 23 deletions
@@ -8,7 +8,6 @@
 import torch.optim as optim
 import torch.nn.functional as F
 import torch.backends.cudnn as cudnn
-from torch.autograd import Variable


 import torchvision
@@ -206,16 +205,14 @@ def train(trainloader, model, criterion, optimizer, epoch):
         data_time.update(time.time() - end)

         input, target = input.cuda(), target.cuda()
-        input_var = Variable(input)
-        target_var = Variable(target)

         # compute output
-        output = model(input_var)
-        loss = criterion(output, target_var)
+        output = model(input)
+        loss = criterion(output, target)

         # measure accuracy and record loss
-        prec = accuracy(output.data, target)[0]
-        losses.update(loss.data[0], input.size(0))
+        prec = accuracy(output, target)[0]
+        losses.update(loss.item(), input.size(0))
         top1.update(prec[0], input.size(0))

         # compute gradient and do SGD step
@@ -246,26 +243,25 @@ def validate(val_loader, model, criterion):
     model.eval()

     end = time.time()
-    for i, (input, target) in enumerate(val_loader):
-        input, target = input.cuda(), target.cuda()
-        input_var = Variable(input, volatile=True)
-        target_var = Variable(target, volatile=True)
+    with torch.no_grad():
+        for i, (input, target) in enumerate(val_loader):
+            input, target = input.cuda(), target.cuda()

-        # compute output
-        output = model(input_var)
-        loss = criterion(output, target_var)
+            # compute output
+            output = model(input)
+            loss = criterion(output, target)

-        # measure accuracy and record loss
-        prec = accuracy(output.data, target)[0]
-        losses.update(loss.data[0], input.size(0))
-        top1.update(prec[0], input.size(0))
+            # measure accuracy and record loss
+            prec = accuracy(output, target)[0]
+            losses.update(loss.item(), input.size(0))
+            top1.update(prec[0], input.size(0))

-        # measure elapsed time
-        batch_time.update(time.time() - end)
-        end = time.time()
+            # measure elapsed time
+            batch_time.update(time.time() - end)
+            end = time.time()

-        if i % args.print_freq == 0:
-            print('Test: [{0}/{1}]\t'
+            if i % args.print_freq == 0:
+                print('Test: [{0}/{1}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                   'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format(
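For context, the main.py hunks follow the standard PyTorch >= 0.4 migration: Variable was merged into Tensor, so inputs and targets are passed to the model directly; 0-dim loss tensors are read with loss.item() instead of loss.data[0]; .data is no longer needed when handing the output to accuracy(); and volatile=True is replaced by wrapping the evaluation loop in torch.no_grad(). Below is a minimal, self-contained sketch of that validation pattern; the model, criterion, and loader are toy placeholders, not code from this repository.

import torch
import torch.nn as nn

# toy stand-ins, only to show the post-0.4 evaluation pattern
model = nn.Linear(10, 2).eval()
criterion = nn.CrossEntropyLoss()
val_loader = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(3)]

total_loss, total_samples = 0.0, 0
with torch.no_grad():                              # replaces Variable(..., volatile=True)
    for input, target in val_loader:
        output = model(input)                      # plain tensors, no Variable wrapper
        loss = criterion(output, target)
        total_loss += loss.item() * input.size(0)  # .item() replaces loss.data[0]
        total_samples += input.size(0)

print('val loss: {:.4f}'.format(total_loss / total_samples))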
models/densenet_cifar.py: 7 additions, 7 deletions
@@ -15,13 +15,13 @@
 class _DenseLayer(nn.Sequential):
     def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
         super(_DenseLayer, self).__init__()
-        self.add_module('norm.1', nn.BatchNorm2d(num_input_features)),
-        self.add_module('relu.1', nn.ReLU(inplace=True)),
-        self.add_module('conv.1', nn.Conv2d(num_input_features, bn_size *
+        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
+        self.add_module('relu1', nn.ReLU(inplace=True)),
+        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                         growth_rate, kernel_size=1, stride=1, bias=False)),
-        self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate)),
-        self.add_module('relu.2', nn.ReLU(inplace=True)),
-        self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate,
+        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
+        self.add_module('relu2', nn.ReLU(inplace=True)),
+        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                         kernel_size=3, stride=1, padding=1, bias=False)),
         self.drop_rate = drop_rate

@@ -116,7 +116,7 @@ def densenet_BC_cifar(depth, k, **kwargs):

 if __name__ == '__main__':
     net = densenet_BC_cifar(190, 40, num_classes=100)
-    input = torch.autograd.Variable(torch.randn(1, 3, 32, 32))
+    input = torch.randn(1, 3, 32, 32)
     y = net(input)
     print(net)
     print(y.size())
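The renames in _DenseLayer (norm.1 to norm1, and so on) are not cosmetic: recent PyTorch versions reject submodule names that contain a dot, because '.' is the separator used in state_dict keys, and nn.Module.add_module raises KeyError for such names. A short illustration, assuming nothing beyond torch.nn:

import torch.nn as nn

layer = nn.Sequential()
layer.add_module('norm1', nn.BatchNorm2d(8))       # accepted: no dot in the name
layer.add_module('relu1', nn.ReLU(inplace=True))
print(list(layer.state_dict().keys()))             # keys like 'norm1.weight', 'norm1.bias', ...

try:
    layer.add_module('norm.2', nn.BatchNorm2d(8))  # dotted names are rejected in recent PyTorch
except KeyError as exc:
    print('rejected:', exc)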
models/resnet_cifar.py: 2 additions, 2 deletions
@@ -321,8 +321,8 @@ def preact_resnet1001_cifar(**kwargs):


 if __name__ == '__main__':
-    net = preact_resnet110_cifar()
-    y = net(torch.autograd.Variable(torch.randn(1, 3, 32, 32)))
+    net = resnet20_cifar()
+    y = net(torch.randn(1, 3, 64, 64))
     print(net)
     print(y.size())

models/resnext_cifar.py: 1 addition, 2 deletions
@@ -7,7 +7,6 @@

 import torch
 import torch.nn as nn
-from torch.autograd import Variable
 import math


@@ -118,6 +117,6 @@ def resneXt_cifar(depth, cardinality, baseWidth, **kwargs):

 if __name__ == '__main__':
     net = resneXt_cifar(29, 16, 64)
-    y = net(Variable(torch.randn(1, 3, 32, 32)))
+    y = net(torch.randn(1, 3, 32, 32))
     print(net)
     print(y.size())
models/wide_resnet_cifar.py: 1 addition, 2 deletions
@@ -6,7 +6,6 @@
 """
 import torch
 import torch.nn as nn
-from torch.autograd import Variable
 import math
 from resnet_cifar import BasicBlock

@@ -73,7 +72,7 @@ def wide_resnet_cifar(depth, width, **kwargs):

 if __name__=='__main__':
     net = wide_resnet_cifar(20, 10)
-    y = net(Variable(torch.randn(1, 3, 32, 32)))
+    y = net(torch.randn(1, 3, 32, 32))
     print(isinstance(net, Wide_ResNet_Cifar))
     print(y.size())

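The three model files above (models/resnet_cifar.py, models/resnext_cifar.py, models/wide_resnet_cifar.py) all make the same change to their __main__ smoke tests: the Variable import and wrapper are dropped and a plain tensor is fed to the network. A sketch of that pattern with a stand-in module (the repo's actual constructors are not reproduced here):

import torch
import torch.nn as nn

# stand-in for resnet20_cifar() / resneXt_cifar(29, 16, 64) / wide_resnet_cifar(20, 10)
net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)

y = net(torch.randn(1, 3, 32, 32))   # no Variable(...) wrapper needed in PyTorch >= 0.4
print(y.size())                      # torch.Size([1, 10])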