Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit bfa4cb6

Browse filesBrowse files
committed
tranpose op: perm as input tensor
1 parent 45c1509 commit bfa4cb6
Copy full SHA for bfa4cb6

File tree

Expand file treeCollapse file tree

3 files changed

+88
-68
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+88
-68
lines changed
+14-4Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
#ifndef _TRANSPOSE_TEST_H
22
#define _TRANSPOSE_TEST_H
33

4-
static const unsigned short transpose_axes_arr[3] = { 2,1,0 };
5-
static const float random_input_arr[15] = { 3.484638214111328, 2.033799886703491, 3.2437448501586914, 4.783249855041504, 3.497023582458496, 3.511240005493164, 1.558927297592163, 3.7084484100341797, 2.570117712020874, 0.2405869960784912, 1.8713605403900146, 4.19132661819458, 0.6596618890762329, 0.9029078483581543, 0.2223271131515503 };
6-
static const float ref_output_arr[15] = { 3.484638214111328, 3.511240005493164, 1.8713605403900146, 2.033799886703491, 1.558927297592163, 4.19132661819458, 3.2437448501586914, 3.7084484100341797, 0.6596618890762329, 4.783249855041504, 2.570117712020874, 0.9029078483581543, 3.497023582458496, 0.2405869960784912, 0.2223271131515503 };
4+
static const int32_t transpose_perm_arr[4] = {2, 1, 0, 3};
5+
static const float random_input_arr[15] = {
6+
3.484638214111328, 2.033799886703491, 3.2437448501586914,
7+
4.783249855041504, 3.497023582458496, 3.511240005493164,
8+
1.558927297592163, 3.7084484100341797, 2.570117712020874,
9+
0.2405869960784912, 1.8713605403900146, 4.19132661819458,
10+
0.6596618890762329, 0.9029078483581543, 0.2223271131515503};
11+
static const float ref_output_arr[15] = {
12+
3.484638214111328, 3.511240005493164, 1.8713605403900146,
13+
2.033799886703491, 1.558927297592163, 4.19132661819458,
14+
3.2437448501586914, 3.7084484100341797, 0.6596618890762329,
15+
4.783249855041504, 2.570117712020874, 0.9029078483581543,
16+
3.497023582458496, 0.2405869960784912, 0.2223271131515503};
717

8-
#endif // _TRANSPOSE
18+
#endif // _TRANSPOSE_TEST_H

‎TESTS/operators/test_transpose.cpp

Copy file name to clipboardExpand all lines: TESTS/operators/test_transpose.cpp
+6-6Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
#include <iostream>
33

44
#include "RamTensor.hpp"
5-
#include "Transpose.hpp"
65
#include "RomTensor.hpp"
6+
#include "Transpose.hpp"
77
#include "arenaAllocator.hpp"
88
#include "constants_transpose.hpp"
99
#include "context.hpp"
@@ -19,19 +19,19 @@ TEST(Transpose, transpose_test) {
1919
localCircularArenaAllocator<15 * 2 * sizeof(float), uint32_t> ram_allocator;
2020
Context::get_default_context()->set_metadata_allocator(&meta_allocator);
2121
Context::get_default_context()->set_ram_data_allocator(&ram_allocator);
22-
22+
2323
Tensor input_tensor = new RomTensor({3, 1, 5, 1}, flt, random_input_arr);
24+
Tensor perm_tensor = new RomTensor({4}, i32, transpose_perm_arr);
2425

2526
TensorShape input_target_shape(3, 1, 5, 1);
2627
TensorShape input_shape = input_tensor->get_shape();
2728
EXPECT_TRUE(input_target_shape == input_shape);
2829

29-
Tensor transpose_axes = new RomTensor({4}, u8, transpose_axes_arr);
3030
Tensor output_tensor = new RamTensor(flt);
31-
TransposeOperator<float> op({2,1,0,3});
32-
31+
TransposeOperator<float> op;
3332

34-
op.set_inputs({{TransposeOperator<float>::input, input_tensor}})
33+
op.set_inputs({{TransposeOperator<float>::input, input_tensor},
34+
{TransposeOperator<float>::perm, perm_tensor}})
3535
.set_outputs({{TransposeOperator<float>::output, output_tensor}})
3636
.eval();
3737

‎src/uTensor/ops/Transpose.hpp

Copy file name to clipboard
+68-58Lines changed: 68 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,38 @@
11
#ifndef UTENSOR_TRANSPOSE_H
22
#define UTENSOR_TRANSPOSE_H
33

4+
#include <cstring>
5+
46
#include "context.hpp"
5-
#include "types.hpp"
7+
#include "operatorBase.hpp"
68
#include "tensor.hpp"
9+
#include "types.hpp"
710
#include "uTensor_util.hpp"
8-
#include "operatorBase.hpp"
9-
10-
#include <cstring>
1111

1212
namespace uTensor {
1313
namespace ReferenceOperators {
1414

1515
// Transpose (Swap Axes) as a port from Numpy
1616
// using stride interation in the order of transpose axes
1717
template <typename Tin>
18-
class TransposeOperator : public OperatorInterface<1, 1> {
19-
/* reshape input as the shape of output*/
20-
public:
21-
TransposeOperator(const TensorShape&& axes) : _axes(axes) {}
22-
TransposeOperator(const TensorShape& axes) : _axes(axes) {}
23-
24-
enum names_in : uint8_t { input };
18+
class TransposeOperator : public OperatorInterface<2, 1> {
19+
/* reshape input as the shape of output*/
20+
public:
21+
enum names_in : uint8_t { input, perm };
2522
enum names_out : uint8_t { output };
2623

27-
virtual void compute(){
24+
virtual void compute() {
25+
Tensor& perm_tensor = inputs[perm].tensor();
26+
if (perm_tensor.get_shape().num_dims() > 1) {
27+
uTensor_printf(
28+
"the input tensor perm should be a vector (dimension should be 1)\n");
29+
Context::get_default_context()->throwError(new InvalidTensorInputError);
30+
}
31+
if (perm_tensor->get_type() != i32) {
32+
uTensor_printf("expecting perm tensor of element type int32_t\n");
33+
Context::get_default_context()->throwError(
34+
new InvalidTensorDataTypeError);
35+
}
2836
Tensor& input_tensor = inputs[input].tensor();
2937
TensorShape& input_shape = input_tensor.get_shape();
3038
input_shape.update_dims();
@@ -36,78 +44,80 @@ class TransposeOperator : public OperatorInterface<1, 1> {
3644
Tensor& output_tensor = outputs[output].tensor();
3745

3846
// Create a placeholder to calculate the output shape
39-
// Normally this would reference output shape, but since this could (usually would) be referencing the input, let's keep a dedicated value
40-
TensorShape output_shape = TensorShape(1,1,1,1);
47+
// Normally this would reference output shape, but since this could (usually
48+
// would) be referencing the input, let's keep a dedicated value
49+
TensorShape output_shape = TensorShape(1, 1, 1, 1);
4150
TensorStrides output_strides = TensorStrides(output_shape);
4251
TensorShape offsets = TensorShape(input_shape.num_dims());
4352

44-
for (size_t i = 0; i < 4; ++i) {
53+
for (size_t i = 0; i < 4; ++i) {
4554
output_shape[i] = 0;
4655
output_strides[i] = 0;
4756

4857
// Offsets are used to avoid multiple for loops
4958
offsets[i] = 0;
5059
}
5160

52-
for (size_t i = 0; i < (size_t) input_shape.num_dims(); ++i) {
53-
output_shape[_axes[i]] = input_shape[i];
61+
for (size_t i = 0; i < (size_t)input_shape.num_dims(); ++i) {
62+
int32_t axis = static_cast<int32_t>(perm_tensor(i));
63+
output_shape[axis] = input_shape[i];
5464

5565
// output_strides(i) is derived from axes and input_strides
56-
output_strides[_axes[i]] = input_strides[i];
66+
output_strides[axis] = input_strides[i];
5767
}
58-
59-
// Output shape can be asserted once the transform
68+
69+
// Output shape can be asserted once the transform
6070
// effect has been determined
6171
output_shape.update_dims();
6272
output_tensor->resize(output_shape);
6373

6474
// Perform some basic checks
65-
if (input_tensor->num_elems() != output_tensor->num_elems()){
66-
uTensor_printf("inconsistent input and output shape for reshape\n");
67-
Context::get_default_context()->throwError(new InvalidReshapeError);
68-
return;
69-
}
70-
if (input_tensor->get_type() != output_tensor->get_type()){
71-
uTensor_printf("inconsistent input and output data type for reshape\n");
72-
Context::get_default_context()->throwError(new InvalidTensorDataTypeError);
73-
return;
75+
if (input_tensor->num_elems() != output_tensor->num_elems()) {
76+
uTensor_printf("inconsistent input and output shape for reshape\n");
77+
Context::get_default_context()->throwError(new InvalidReshapeError);
78+
return;
79+
}
80+
if (input_tensor->get_type() != output_tensor->get_type()) {
81+
uTensor_printf("inconsistent input and output data type for reshape\n");
82+
Context::get_default_context()->throwError(
83+
new InvalidTensorDataTypeError);
84+
return;
7485
}
75-
if (!_check_input_shape()){
76-
Context::get_default_context()->throwError(new InvalidTensorDataTypeError);
77-
return;
86+
if (!_check_input_shape()) {
87+
Context::get_default_context()->throwError(
88+
new InvalidTensorDataTypeError);
89+
return;
7890
}
7991

8092
// copy data
81-
for (uint32_t i = 0; i < input_tensor->num_elems(); ++i) {
82-
// Index of the source value, must be calculated
83-
// using the output strides and output shape
84-
uint32_t idx = 0;
85-
for (uint32_t j = 0; j < output_shape.num_dims(); j++) {
86-
idx += offsets[j] * output_strides[j];
87-
}
88-
89-
// this is not copy: `output_tensor(i) = input_tensor(i);`
90-
output_tensor(i) = static_cast<Tin>(input_tensor(idx));
93+
for (uint32_t i = 0; i < input_tensor->num_elems(); ++i) {
94+
// Index of the source value, must be calculated
95+
// using the output strides and output shape
96+
uint32_t idx = 0;
97+
for (uint32_t j = 0; j < output_shape.num_dims(); j++) {
98+
idx += offsets[j] * output_strides[j];
99+
}
91100

92-
// Update offsets, to iterate sequentially along strides
93-
// in the order of axes
94-
for (int32_t j = output_shape.num_dims() - 1; j >= 0; j--) {
95-
offsets[j] = (offsets[j] + 1) % (output_shape[j]);
96-
if( offsets[j] > 0 ) {
97-
break;
98-
}
99-
}
100-
}
101+
// this is not copy: `output_tensor(i) = input_tensor(i);`
102+
output_tensor(i) = static_cast<Tin>(input_tensor(idx));
101103

104+
// Update offsets, to iterate sequentially along strides
105+
// in the order of axes
106+
for (int32_t j = output_shape.num_dims() - 1; j >= 0; j--) {
107+
offsets[j] = (offsets[j] + 1) % (output_shape[j]);
108+
if (offsets[j] > 0) {
109+
break;
110+
}
111+
}
112+
}
102113
}
103-
private:
104-
TensorShape _axes;
105114

106-
bool _check_input_shape(){
115+
private:
116+
bool _check_input_shape() {
107117
const Tensor& input_tensor = inputs[input].tensor();
108118
const TensorShape& shape = input_tensor->get_shape();
109119
uint8_t num_dims = shape.num_dims();
110-
for (int i = 0; i < num_dims; ++i){
120+
for (int i = 0; i < num_dims; ++i) {
111121
if (shape[i] < 0) {
112122
uTensor_printf("the output shape must be all positive\n");
113123
return false;
@@ -117,7 +127,7 @@ class TransposeOperator : public OperatorInterface<1, 1> {
117127
}
118128
};
119129

120-
}
121-
}
130+
} // namespace ReferenceOperators
131+
} // namespace uTensor
122132

123-
#endif // UTENSOR_TRANSPOSE_H
133+
#endif // UTENSOR_TRANSPOSE_H

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.