1
1
#ifndef UTENSOR_TRANSPOSE_H
2
2
#define UTENSOR_TRANSPOSE_H
3
3
4
+ #include < cstring>
5
+
4
6
#include " context.hpp"
5
- #include " types .hpp"
7
+ #include " operatorBase .hpp"
6
8
#include " tensor.hpp"
9
+ #include " types.hpp"
7
10
#include " uTensor_util.hpp"
8
- #include " operatorBase.hpp"
9
-
10
- #include < cstring>
11
11
12
12
namespace uTensor {
13
13
namespace ReferenceOperators {
14
14
15
15
// Transpose (Swap Axes) as a port from Numpy
16
16
// using stride interation in the order of transpose axes
17
17
template <typename Tin>
18
- class TransposeOperator : public OperatorInterface <1 , 1 > {
19
- /* reshape input as the shape of output*/
20
- public:
21
- TransposeOperator (const TensorShape&& axes) : _axes(axes) {}
22
- TransposeOperator (const TensorShape& axes) : _axes(axes) {}
23
-
24
- enum names_in : uint8_t { input };
18
+ class TransposeOperator : public OperatorInterface <2 , 1 > {
19
+ /* reshape input as the shape of output*/
20
+ public:
21
+ enum names_in : uint8_t { input, perm };
25
22
enum names_out : uint8_t { output };
26
23
27
- virtual void compute (){
24
+ virtual void compute () {
25
+ Tensor& perm_tensor = inputs[perm].tensor ();
26
+ if (perm_tensor.get_shape ().num_dims () > 1 ) {
27
+ uTensor_printf (
28
+ " the input tensor perm should be a vector (dimension should be 1)\n " );
29
+ Context::get_default_context ()->throwError (new InvalidTensorInputError);
30
+ }
31
+ if (perm_tensor->get_type () != i32) {
32
+ uTensor_printf (" expecting perm tensor of element type int32_t\n " );
33
+ Context::get_default_context ()->throwError (
34
+ new InvalidTensorDataTypeError);
35
+ }
28
36
Tensor& input_tensor = inputs[input].tensor ();
29
37
TensorShape& input_shape = input_tensor.get_shape ();
30
38
input_shape.update_dims ();
@@ -36,78 +44,80 @@ class TransposeOperator : public OperatorInterface<1, 1> {
36
44
Tensor& output_tensor = outputs[output].tensor ();
37
45
38
46
// Create a placeholder to calculate the output shape
39
- // Normally this would reference output shape, but since this could (usually would) be referencing the input, let's keep a dedicated value
40
- TensorShape output_shape = TensorShape (1 ,1 ,1 ,1 );
47
+ // Normally this would reference output shape, but since this could (usually
48
+ // would) be referencing the input, let's keep a dedicated value
49
+ TensorShape output_shape = TensorShape (1 , 1 , 1 , 1 );
41
50
TensorStrides output_strides = TensorStrides (output_shape);
42
51
TensorShape offsets = TensorShape (input_shape.num_dims ());
43
52
44
- for (size_t i = 0 ; i < 4 ; ++i) {
53
+ for (size_t i = 0 ; i < 4 ; ++i) {
45
54
output_shape[i] = 0 ;
46
55
output_strides[i] = 0 ;
47
56
48
57
// Offsets are used to avoid multiple for loops
49
58
offsets[i] = 0 ;
50
59
}
51
60
52
- for (size_t i = 0 ; i < (size_t ) input_shape.num_dims (); ++i) {
53
- output_shape[_axes[i]] = input_shape[i];
61
+ for (size_t i = 0 ; i < (size_t )input_shape.num_dims (); ++i) {
62
+ int32_t axis = static_cast <int32_t >(perm_tensor (i));
63
+ output_shape[axis] = input_shape[i];
54
64
55
65
// output_strides(i) is derived from axes and input_strides
56
- output_strides[_axes[i] ] = input_strides[i];
66
+ output_strides[axis ] = input_strides[i];
57
67
}
58
-
59
- // Output shape can be asserted once the transform
68
+
69
+ // Output shape can be asserted once the transform
60
70
// effect has been determined
61
71
output_shape.update_dims ();
62
72
output_tensor->resize (output_shape);
63
73
64
74
// Perform some basic checks
65
- if (input_tensor->num_elems () != output_tensor->num_elems ()){
66
- uTensor_printf (" inconsistent input and output shape for reshape\n " );
67
- Context::get_default_context ()->throwError (new InvalidReshapeError);
68
- return ;
69
- }
70
- if (input_tensor->get_type () != output_tensor->get_type ()){
71
- uTensor_printf (" inconsistent input and output data type for reshape\n " );
72
- Context::get_default_context ()->throwError (new InvalidTensorDataTypeError);
73
- return ;
75
+ if (input_tensor->num_elems () != output_tensor->num_elems ()) {
76
+ uTensor_printf (" inconsistent input and output shape for reshape\n " );
77
+ Context::get_default_context ()->throwError (new InvalidReshapeError);
78
+ return ;
79
+ }
80
+ if (input_tensor->get_type () != output_tensor->get_type ()) {
81
+ uTensor_printf (" inconsistent input and output data type for reshape\n " );
82
+ Context::get_default_context ()->throwError (
83
+ new InvalidTensorDataTypeError);
84
+ return ;
74
85
}
75
- if (!_check_input_shape ()){
76
- Context::get_default_context ()->throwError (new InvalidTensorDataTypeError);
77
- return ;
86
+ if (!_check_input_shape ()) {
87
+ Context::get_default_context ()->throwError (
88
+ new InvalidTensorDataTypeError);
89
+ return ;
78
90
}
79
91
80
92
// copy data
81
- for (uint32_t i = 0 ; i < input_tensor->num_elems (); ++i) {
82
- // Index of the source value, must be calculated
83
- // using the output strides and output shape
84
- uint32_t idx = 0 ;
85
- for (uint32_t j = 0 ; j < output_shape.num_dims (); j++) {
86
- idx += offsets[j] * output_strides[j];
87
- }
88
-
89
- // this is not copy: `output_tensor(i) = input_tensor(i);`
90
- output_tensor (i) = static_cast <Tin>(input_tensor (idx));
93
+ for (uint32_t i = 0 ; i < input_tensor->num_elems (); ++i) {
94
+ // Index of the source value, must be calculated
95
+ // using the output strides and output shape
96
+ uint32_t idx = 0 ;
97
+ for (uint32_t j = 0 ; j < output_shape.num_dims (); j++) {
98
+ idx += offsets[j] * output_strides[j];
99
+ }
91
100
92
- // Update offsets, to iterate sequentially along strides
93
- // in the order of axes
94
- for (int32_t j = output_shape.num_dims () - 1 ; j >= 0 ; j--) {
95
- offsets[j] = (offsets[j] + 1 ) % (output_shape[j]);
96
- if ( offsets[j] > 0 ) {
97
- break ;
98
- }
99
- }
100
- }
101
+ // this is not copy: `output_tensor(i) = input_tensor(i);`
102
+ output_tensor (i) = static_cast <Tin>(input_tensor (idx));
101
103
104
+ // Update offsets, to iterate sequentially along strides
105
+ // in the order of axes
106
+ for (int32_t j = output_shape.num_dims () - 1 ; j >= 0 ; j--) {
107
+ offsets[j] = (offsets[j] + 1 ) % (output_shape[j]);
108
+ if (offsets[j] > 0 ) {
109
+ break ;
110
+ }
111
+ }
112
+ }
102
113
}
103
- private:
104
- TensorShape _axes;
105
114
106
- bool _check_input_shape (){
115
+ private:
116
+ bool _check_input_shape () {
107
117
const Tensor& input_tensor = inputs[input].tensor ();
108
118
const TensorShape& shape = input_tensor->get_shape ();
109
119
uint8_t num_dims = shape.num_dims ();
110
- for (int i = 0 ; i < num_dims; ++i){
120
+ for (int i = 0 ; i < num_dims; ++i) {
111
121
if (shape[i] < 0 ) {
112
122
uTensor_printf (" the output shape must be all positive\n " );
113
123
return false ;
@@ -117,7 +127,7 @@ class TransposeOperator : public OperatorInterface<1, 1> {
117
127
}
118
128
};
119
129
120
- }
121
- }
130
+ } // namespace ReferenceOperators
131
+ } // namespace uTensor
122
132
123
- #endif // UTENSOR_TRANSPOSE_H
133
+ #endif // UTENSOR_TRANSPOSE_H
0 commit comments