@@ -7,16 +7,21 @@ def _convert_sequential(builder, name, layer, input_names, output_names):
7
7
8
8
inputs = input_names
9
9
for i in range (n ):
10
- l = layers [i ]
10
+ l_ = layers [i ]
11
11
12
12
l_outputs = None
13
- l_name = _gen_layer_name (l )
13
+ l_name = _gen_layer_name (l_ )
14
14
if i != (n - 1 ):
15
- l_outputs = [l_name ]
15
+ if isinstance (l_ .output , list ):
16
+ l_outputs = [
17
+ "{}_{}" .format (l_name , i )for i in range (len (l_ .output ))
18
+ ]
19
+ else :
20
+ l_outputs = [l_name ]
16
21
else :
17
22
l_outputs = output_names
18
23
19
- l_outputs = _convert_layer (builder , l_name , l , inputs , l_outputs )
24
+ l_outputs = _convert_layer (builder , l_name , l_ , inputs , l_outputs )
20
25
inputs = l_outputs
21
26
22
27
return output_names
@@ -154,6 +159,20 @@ def _convert_concat_table(builder, name, layer, input_names, output_names):
154
159
return result_outputs
155
160
156
161
162
+ def _convert_parallel_table (builder , name , layer , input_names , output_names ):
163
+ layers = layer .modules
164
+ assert len (input_names ) == len (layers )
165
+ result_outputs = []
166
+ for i in range (len (layers )):
167
+ l_ = layers [i ]
168
+ l_name = _gen_layer_name (l_ )
169
+ l_outputs = _convert_layer (
170
+ builder , l_name , l_ , [input_names [i ]], [l_name ]
171
+ )
172
+ result_outputs .append (l_outputs [0 ])
173
+ return result_outputs
174
+
175
+
157
176
def _convert_batch_norm (builder , name , layer , input_names , output_names ):
158
177
epsilon = layer .eps
159
178
mean = layer .running_mean .numpy ()
@@ -192,6 +211,43 @@ def _convert_cadd_table(builder, name, layer, input_names, output_names):
192
211
return output_names
193
212
194
213
214
+ def _convert_cdiv_table (builder , name , layer , input_names , output_names ):
215
+ assert len (input_names ) == 2
216
+ assert len (output_names ) == 1
217
+
218
+ inverse_layer_name = _gen_layer_name ('inverse' )
219
+ inverse_layer_output_name = inverse_layer_name + '_output'
220
+
221
+ builder .add_unary (
222
+ name = inverse_layer_name ,
223
+ input_name = input_names [1 ],
224
+ output_name = inverse_layer_output_name ,
225
+ mode = 'inverse'
226
+ )
227
+
228
+ builder .add_elementwise (
229
+ name = name ,
230
+ input_names = [input_names [0 ], inverse_layer_output_name ],
231
+ output_name = output_names [0 ],
232
+ mode = 'MULTIPLY'
233
+ )
234
+ return output_names
235
+
236
+
237
+ def _convert_cmul_table (builder , name , layer , input_names , output_names ):
238
+ assert len (input_names ) > 1
239
+ assert len (output_names ) == 1
240
+
241
+ builder .add_elementwise (
242
+ name = name ,
243
+ input_names = input_names ,
244
+ output_name = output_names [0 ],
245
+ mode = 'MULTIPLY'
246
+ )
247
+
248
+ return output_names
249
+
250
+
195
251
def _convert_identity (builder , name , layer , input_names , output_names ):
196
252
return input_names
197
253
@@ -299,13 +355,12 @@ def _convert_tanh(builder, name, layer, input_names, output_names):
299
355
300
356
def _convert_mul_constant (builder , name , layer , input_names , output_names ):
301
357
scalar = float (layer .constant_scalar )
302
-
303
- builder .add_activation (
358
+ builder .add_elementwise (
304
359
name = name ,
305
- non_linearity = 'LINEAR' ,
306
- input_name = input_names [0 ],
360
+ input_names = [input_names [0 ]],
307
361
output_name = output_names [0 ],
308
- params = [scalar , 0.0 ]
362
+ mode = 'MULTIPLY' ,
363
+ alpha = scalar
309
364
)
310
365
311
366
return output_names
@@ -438,6 +493,46 @@ def _convert_split_table(builder, name, layer, input_names, output_names):
438
493
return output_names
439
494
440
495
496
+ def _convert_log (builder , name , layer , input_names , output_names ):
497
+ builder .add_unary (
498
+ name = name ,
499
+ input_name = input_names [0 ],
500
+ output_name = output_names [0 ],
501
+ mode = 'log'
502
+ )
503
+ return output_names
504
+
505
+
506
+ def _convert_sigmoid (builder , name , layer , input_names , output_names ):
507
+ builder .add_activation (
508
+ name = name ,
509
+ non_linearity = 'SIGMOID' ,
510
+ input_name = input_names [0 ],
511
+ output_name = output_names [0 ]
512
+ )
513
+ return output_names
514
+
515
+
516
+ def _convert_power (builder , name , layer , input_names , output_names ):
517
+ p = layer .pow
518
+ if p == - 1 :
519
+ builder .add_unary (
520
+ name = name ,
521
+ input_name = input_names [0 ],
522
+ output_name = output_names [0 ],
523
+ mode = 'inverse'
524
+ )
525
+ else :
526
+ builder .add_unary (
527
+ name = name ,
528
+ input_name = input_names [0 ],
529
+ output_name = output_names [0 ],
530
+ mode = 'power' ,
531
+ alpha = float (p )
532
+ )
533
+ return output_names
534
+
535
+
441
536
_TORCH_LAYER_REGISTRY = {
442
537
'Sequential' : _convert_sequential ,
443
538
'SpatialConvolution' : _convert_convolution ,
@@ -460,7 +555,13 @@ def _convert_split_table(builder, name, layer, input_names, output_names):
460
555
'Narrow' : _convert_narrow ,
461
556
'SpatialReflectionPadding' : _convert_reflection_padding ,
462
557
'SpatialUpSamplingNearest' : _convert_upsampling_nearest ,
463
- 'SplitTable' : _convert_split_table
558
+ 'SplitTable' : _convert_split_table ,
559
+ 'CDivTable' : _convert_cdiv_table ,
560
+ 'Log' : _convert_log ,
561
+ 'Sigmoid' : _convert_sigmoid ,
562
+ 'ParallelTable' : _convert_parallel_table ,
563
+ 'Power' : _convert_power ,
564
+ 'CMulTable' : _convert_cmul_table ,
464
565
}
465
566
466
567
0 commit comments