Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 63875fa

Browse filesBrowse files
authored
Fix default values (#63)
1 parent c7375ca commit 63875fa
Copy full SHA for 63875fa

File tree

1 file changed

+81
-82
lines changed
Filter options

1 file changed

+81
-82
lines changed

‎onnx_array_api/light_api/_op_vars.py

Copy file name to clipboardExpand all lines: onnx_array_api/light_api/_op_vars.py
+81-82Lines changed: 81 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@ def BitShift(self, direction: str = "") -> "Var":
1010
return self.make_node("BitShift", *self.vars_, direction=direction)
1111

1212
def CenterCropPad(self, axes: Optional[List[int]] = None) -> "Var":
13-
axes = axes or []
14-
return self.make_node("CenterCropPad", *self.vars_, axes=axes)
13+
kwargs = {}
14+
if axes is not None:
15+
kwargs["axes"] = axes
16+
return self.make_node("CenterCropPad", *self.vars_, **kwargs)
1517

1618
def Clip(
1719
self,
@@ -27,12 +29,14 @@ def Col2Im(
2729
pads: Optional[List[int]] = None,
2830
strides: Optional[List[int]] = None,
2931
) -> "Var":
30-
dilations = dilations or []
31-
pads = pads or []
32-
strides = strides or []
33-
return self.make_node(
34-
"Col2Im", *self.vars_, dilations=dilations, pads=pads, strides=strides
35-
)
32+
kwargs = {}
33+
if dilations is not None:
34+
kwargs["dilations"] = dilations
35+
if pads is not None:
36+
kwargs["pads"] = pads
37+
if strides is not None:
38+
kwargs["strides"] = strides
39+
return self.make_node("Col2Im", *self.vars_, **kwargs)
3640

3741
def Compress(self, axis: int = 0) -> "Var":
3842
return self.make_node("Compress", *self.vars_, axis=axis)
@@ -71,19 +75,17 @@ def ConvInteger(
7175
pads: Optional[List[int]] = None,
7276
strides: Optional[List[int]] = None,
7377
) -> "Var":
74-
dilations = dilations or []
75-
kernel_shape = kernel_shape or []
76-
pads = pads or []
77-
strides = strides or []
78+
kwargs = {}
79+
if dilations is not None:
80+
kwargs["dilations"] = dilations
81+
if kernel_shape is not None:
82+
kwargs["kernel_shape"] = kernel_shape
83+
if pads is not None:
84+
kwargs["pads"] = pads
85+
if strides is not None:
86+
kwargs["strides"] = strides
7887
return self.make_node(
79-
"ConvInteger",
80-
*self.vars_,
81-
auto_pad=auto_pad,
82-
dilations=dilations,
83-
group=group,
84-
kernel_shape=kernel_shape,
85-
pads=pads,
86-
strides=strides,
88+
"ConvInteger", *self.vars_, auto_pad=auto_pad, group=group, **kwargs
8789
)
8890

8991
def ConvTranspose(
@@ -97,23 +99,21 @@ def ConvTranspose(
9799
pads: Optional[List[int]] = None,
98100
strides: Optional[List[int]] = None,
99101
) -> "Var":
100-
dilations = dilations or []
101-
kernel_shape = kernel_shape or []
102-
output_padding = output_padding or []
103-
output_shape = output_shape or []
104-
pads = pads or []
105-
strides = strides or []
106-
return self.make_node(
107-
"ConvTranspose",
108-
*self.vars_,
109-
auto_pad=auto_pad,
110-
dilations=dilations,
111-
group=group,
112-
kernel_shape=kernel_shape,
113-
output_padding=output_padding,
114-
output_shape=output_shape,
115-
pads=pads,
116-
strides=strides,
102+
kwargs = {}
103+
if dilations is not None:
104+
kwargs["dilations"] = dilations
105+
if kernel_shape is not None:
106+
kwargs["kernel_shape"] = kernel_shape
107+
if pads is not None:
108+
kwargs["pads"] = pads
109+
if strides is not None:
110+
kwargs["strides"] = strides
111+
if output_padding is not None:
112+
kwargs["output_padding"] = output_padding
113+
if output_shape is not None:
114+
kwargs["output_shape"] = output_shape
115+
return self.make_node(
116+
"ConvTranspose", *self.vars_, auto_pad=auto_pad, group=group, **kwargs
117117
)
118118

119119
def CumSum(self, exclusive: int = 0, reverse: int = 0) -> "Var":
@@ -135,19 +135,17 @@ def DeformConv(
135135
pads: Optional[List[int]] = None,
136136
strides: Optional[List[int]] = None,
137137
) -> "Var":
138-
dilations = dilations or []
139-
kernel_shape = kernel_shape or []
140-
pads = pads or []
141-
strides = strides or []
138+
kwargs = {}
139+
if dilations is not None:
140+
kwargs["dilations"] = dilations
141+
if kernel_shape is not None:
142+
kwargs["kernel_shape"] = kernel_shape
143+
if pads is not None:
144+
kwargs["pads"] = pads
145+
if strides is not None:
146+
kwargs["strides"] = strides
142147
return self.make_node(
143-
"DeformConv",
144-
*self.vars_,
145-
dilations=dilations,
146-
group=group,
147-
kernel_shape=kernel_shape,
148-
offset_group=offset_group,
149-
pads=pads,
150-
strides=strides,
148+
"DeformConv", *self.vars_, group=group, offset_group=offset_group, **kwargs
151149
)
152150

153151
def DequantizeLinear(self, axis: int = 1) -> "Var":
@@ -204,12 +202,11 @@ def MatMulInteger(
204202
def MaxRoiPool(
205203
self, pooled_shape: Optional[List[int]] = None, spatial_scale: float = 1.0
206204
) -> "Var":
207-
pooled_shape = pooled_shape or []
205+
kwargs = {}
206+
if pooled_shape is not None:
207+
kwargs["pooled_shape"] = pooled_shape
208208
return self.make_node(
209-
"MaxRoiPool",
210-
*self.vars_,
211-
pooled_shape=pooled_shape,
212-
spatial_scale=spatial_scale,
209+
"MaxRoiPool", *self.vars_, spatial_scale=spatial_scale, **kwargs
213210
)
214211

215212
def MaxUnpool(
@@ -218,16 +215,14 @@ def MaxUnpool(
218215
pads: Optional[List[int]] = None,
219216
strides: Optional[List[int]] = None,
220217
) -> "Var":
221-
kernel_shape = kernel_shape or []
222-
pads = pads or []
223-
strides = strides or []
224-
return self.make_node(
225-
"MaxUnpool",
226-
*self.vars_,
227-
kernel_shape=kernel_shape,
228-
pads=pads,
229-
strides=strides,
230-
)
218+
kwargs = {}
219+
if kernel_shape is not None:
220+
kwargs["kernel_shape"] = kernel_shape
221+
if pads is not None:
222+
kwargs["pads"] = pads
223+
if strides is not None:
224+
kwargs["strides"] = strides
225+
return self.make_node("MaxUnpool", *self.vars_, **kwargs)
231226

232227
def MelWeightMatrix(self, output_datatype: int = 1) -> "Var":
233228
return self.make_node(
@@ -267,19 +262,17 @@ def QLinearConv(
267262
pads: Optional[List[int]] = None,
268263
strides: Optional[List[int]] = None,
269264
) -> "Var":
270-
dilations = dilations or []
271-
kernel_shape = kernel_shape or []
272-
pads = pads or []
273-
strides = strides or []
265+
kwargs = {}
266+
if kernel_shape is not None:
267+
kwargs["kernel_shape"] = kernel_shape
268+
if pads is not None:
269+
kwargs["pads"] = pads
270+
if strides is not None:
271+
kwargs["strides"] = strides
272+
if dilations is not None:
273+
kwargs["dilations"] = dilations
274274
return self.make_node(
275-
"QLinearConv",
276-
*self.vars_,
277-
auto_pad=auto_pad,
278-
dilations=dilations,
279-
group=group,
280-
kernel_shape=kernel_shape,
281-
pads=pads,
282-
strides=strides,
275+
"QLinearConv", *self.vars_, auto_pad=auto_pad, group=group, **kwargs
283276
)
284277

285278
def QLinearMatMul(
@@ -303,15 +296,17 @@ def RandomNormal(
303296
seed: float = 0.0,
304297
shape: Optional[List[int]] = None,
305298
) -> "Var":
306-
shape = shape or []
299+
kwargs = {}
300+
if shape is not None:
301+
kwargs["shape"] = shape
307302
return self.make_node(
308303
"RandomNormal",
309304
*self.vars_,
310305
dtype=dtype,
311306
mean=mean,
312307
scale=scale,
313308
seed=seed,
314-
shape=shape,
309+
**kwargs,
315310
)
316311

317312
def RandomUniform(
@@ -322,15 +317,17 @@ def RandomUniform(
322317
seed: float = 0.0,
323318
shape: Optional[List[int]] = None,
324319
) -> "Var":
325-
shape = shape or []
320+
kwargs = {}
321+
if shape is not None:
322+
kwargs["shape"] = shape
326323
return self.make_node(
327324
"RandomUniform",
328325
*self.vars_,
329326
dtype=dtype,
330327
high=high,
331328
low=low,
332329
seed=seed,
333-
shape=shape,
330+
**kwargs,
334331
)
335332

336333
def Range(
@@ -437,19 +434,21 @@ def Resize(
437434
mode: str = "nearest",
438435
nearest_mode: str = "round_prefer_floor",
439436
) -> "Var":
440-
axes = axes or []
437+
kwargs = {}
438+
if axes is not None:
439+
kwargs["axes"] = axes
441440
return self.make_node(
442441
"Resize",
443442
*self.vars_,
444443
antialias=antialias,
445-
axes=axes,
446444
coordinate_transformation_mode=coordinate_transformation_mode,
447445
cubic_coeff_a=cubic_coeff_a,
448446
exclude_outside=exclude_outside,
449447
extrapolation_value=extrapolation_value,
450448
keep_aspect_ratio_policy=keep_aspect_ratio_policy,
451449
mode=mode,
452450
nearest_mode=nearest_mode,
451+
**kwargs,
453452
)
454453

455454
def RoiAlign(

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.