Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
This repository was archived by the owner on Sep 25, 2023. It is now read-only.

Commit ee461ba

Browse filesBrowse files
author
Oleg Poyaganov
committed
add new layers
1 parent d613a9f commit ee461ba
Copy full SHA for ee461ba

File tree

Expand file treeCollapse file tree

2 files changed

+117
-10
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+117
-10
lines changed

‎test/test_layers.py

Copy file name to clipboardExpand all lines: test/test_layers.py
+6Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,9 @@ def test_split_table(self):
181181
self.output_count = 3
182182
self.torch_batch_mode = False
183183
self._test_single_layer(nn.SplitTable(0))
184+
185+
def test_sigmoid(self):
186+
self._test_single_layer(nn.Sigmoid())
187+
188+
def test_power(self):
189+
self._test_single_layer(nn.Power(2))

‎torch2coreml/_layers.py

Copy file name to clipboardExpand all lines: torch2coreml/_layers.py
+111-10Lines changed: 111 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,21 @@ def _convert_sequential(builder, name, layer, input_names, output_names):
77

88
inputs = input_names
99
for i in range(n):
10-
l = layers[i]
10+
l_ = layers[i]
1111

1212
l_outputs = None
13-
l_name = _gen_layer_name(l)
13+
l_name = _gen_layer_name(l_)
1414
if i != (n - 1):
15-
l_outputs = [l_name]
15+
if isinstance(l_.output, list):
16+
l_outputs = [
17+
"{}_{}".format(l_name, i)for i in range(len(l_.output))
18+
]
19+
else:
20+
l_outputs = [l_name]
1621
else:
1722
l_outputs = output_names
1823

19-
l_outputs = _convert_layer(builder, l_name, l, inputs, l_outputs)
24+
l_outputs = _convert_layer(builder, l_name, l_, inputs, l_outputs)
2025
inputs = l_outputs
2126

2227
return output_names
@@ -154,6 +159,20 @@ def _convert_concat_table(builder, name, layer, input_names, output_names):
154159
return result_outputs
155160

156161

162+
def _convert_parallel_table(builder, name, layer, input_names, output_names):
163+
layers = layer.modules
164+
assert len(input_names) == len(layers)
165+
result_outputs = []
166+
for i in range(len(layers)):
167+
l_ = layers[i]
168+
l_name = _gen_layer_name(l_)
169+
l_outputs = _convert_layer(
170+
builder, l_name, l_, [input_names[i]], [l_name]
171+
)
172+
result_outputs.append(l_outputs[0])
173+
return result_outputs
174+
175+
157176
def _convert_batch_norm(builder, name, layer, input_names, output_names):
158177
epsilon = layer.eps
159178
mean = layer.running_mean.numpy()
@@ -192,6 +211,43 @@ def _convert_cadd_table(builder, name, layer, input_names, output_names):
192211
return output_names
193212

194213

214+
def _convert_cdiv_table(builder, name, layer, input_names, output_names):
215+
assert len(input_names) == 2
216+
assert len(output_names) == 1
217+
218+
inverse_layer_name = _gen_layer_name('inverse')
219+
inverse_layer_output_name = inverse_layer_name + '_output'
220+
221+
builder.add_unary(
222+
name=inverse_layer_name,
223+
input_name=input_names[1],
224+
output_name=inverse_layer_output_name,
225+
mode='inverse'
226+
)
227+
228+
builder.add_elementwise(
229+
name=name,
230+
input_names=[input_names[0], inverse_layer_output_name],
231+
output_name=output_names[0],
232+
mode='MULTIPLY'
233+
)
234+
return output_names
235+
236+
237+
def _convert_cmul_table(builder, name, layer, input_names, output_names):
238+
assert len(input_names) > 1
239+
assert len(output_names) == 1
240+
241+
builder.add_elementwise(
242+
name=name,
243+
input_names=input_names,
244+
output_name=output_names[0],
245+
mode='MULTIPLY'
246+
)
247+
248+
return output_names
249+
250+
195251
def _convert_identity(builder, name, layer, input_names, output_names):
196252
return input_names
197253

@@ -299,13 +355,12 @@ def _convert_tanh(builder, name, layer, input_names, output_names):
299355

300356
def _convert_mul_constant(builder, name, layer, input_names, output_names):
301357
scalar = float(layer.constant_scalar)
302-
303-
builder.add_activation(
358+
builder.add_elementwise(
304359
name=name,
305-
non_linearity='LINEAR',
306-
input_name=input_names[0],
360+
input_names=[input_names[0]],
307361
output_name=output_names[0],
308-
params=[scalar, 0.0]
362+
mode='MULTIPLY',
363+
alpha=scalar
309364
)
310365

311366
return output_names
@@ -438,6 +493,46 @@ def _convert_split_table(builder, name, layer, input_names, output_names):
438493
return output_names
439494

440495

496+
def _convert_log(builder, name, layer, input_names, output_names):
497+
builder.add_unary(
498+
name=name,
499+
input_name=input_names[0],
500+
output_name=output_names[0],
501+
mode='log'
502+
)
503+
return output_names
504+
505+
506+
def _convert_sigmoid(builder, name, layer, input_names, output_names):
507+
builder.add_activation(
508+
name=name,
509+
non_linearity='SIGMOID',
510+
input_name=input_names[0],
511+
output_name=output_names[0]
512+
)
513+
return output_names
514+
515+
516+
def _convert_power(builder, name, layer, input_names, output_names):
517+
p = layer.pow
518+
if p == -1:
519+
builder.add_unary(
520+
name=name,
521+
input_name=input_names[0],
522+
output_name=output_names[0],
523+
mode='inverse'
524+
)
525+
else:
526+
builder.add_unary(
527+
name=name,
528+
input_name=input_names[0],
529+
output_name=output_names[0],
530+
mode='power',
531+
alpha=float(p)
532+
)
533+
return output_names
534+
535+
441536
_TORCH_LAYER_REGISTRY = {
442537
'Sequential': _convert_sequential,
443538
'SpatialConvolution': _convert_convolution,
@@ -460,7 +555,13 @@ def _convert_split_table(builder, name, layer, input_names, output_names):
460555
'Narrow': _convert_narrow,
461556
'SpatialReflectionPadding': _convert_reflection_padding,
462557
'SpatialUpSamplingNearest': _convert_upsampling_nearest,
463-
'SplitTable': _convert_split_table
558+
'SplitTable': _convert_split_table,
559+
'CDivTable': _convert_cdiv_table,
560+
'Log': _convert_log,
561+
'Sigmoid': _convert_sigmoid,
562+
'ParallelTable': _convert_parallel_table,
563+
'Power': _convert_power,
564+
'CMulTable': _convert_cmul_table,
464565
}
465566

466567

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.