Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 5057d33

Browse filesBrowse files
committed
improve comments based on review. Remove 'negative' functions and logic
1 parent 0046727 commit 5057d33
Copy full SHA for 5057d33

File tree

Expand file treeCollapse file tree

2 files changed

+32
-56
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+32
-56
lines changed

‎mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Copy file name to clipboardExpand all lines: mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+31-50Lines changed: 31 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -444,59 +444,40 @@ struct LinearizeVectorSplat final
444444

445445
} // namespace
446446

447-
/// Some operations currently will not be linearized if they have scalable
448-
/// vector results, although support should be added in the future. This
449-
/// function returns true if `op` is such an operation.
450-
static bool isNotLinearizableBecauseScalable(Operation *op) {
451-
452-
bool unsupported =
453-
isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
454-
op);
455-
456-
// Case where linearization is possible even when there are scalable vector
457-
// results.
458-
if (!unsupported)
459-
return false;
460-
461-
// Check if any of the results is a scalable vector type, and if there are
462-
// return true (not linearizable).
463-
auto types = op->getResultTypes();
464-
bool containsScalableResult =
465-
std::any_of(types.begin(), types.end(), [](Type type) {
466-
auto vecType = dyn_cast<VectorType>(type);
467-
return vecType && vecType.isScalable();
468-
});
469-
470-
return containsScalableResult;
471-
}
472-
473-
/// This method defines a set of operations that are not linearizable, and hence
474-
/// they are considered legal for the conversion target. These ops are
475-
/// currently,
476-
///
477-
/// 1) ones that are not in the vector dialect, are not ConstantLike, and are
478-
/// not Vectorizable, or
479-
///
480-
/// 2) have scalable vector results, for which support has not yet been added.
481-
static bool isNotLinearizable(Operation *op) {
447+
/// This method defines the set of operations that are linearizable, and hence
448+
/// that are considered illegal for the conversion target.
449+
static bool isLinearizable(Operation *op) {
482450

451+
// Operations such as builtin.module are not linearizable.
483452
StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
484453
StringRef opDialect = op->getDialect()->getNamespace();
485-
bool unsupported = (opDialect != vectorDialect) &&
486-
!op->hasTrait<OpTrait::ConstantLike>() &&
487-
!op->hasTrait<OpTrait::Vectorizable>();
488-
if (unsupported)
489-
return true;
490-
491-
// vector.shape_cast cannot be linearized.
492-
if (isa<vector::ShapeCastOp>(op))
493-
return true;
494-
495-
// Some ops currently don't support scalable vectors.
496-
if (isNotLinearizableBecauseScalable(op))
497-
return true;
454+
bool supported = (opDialect == vectorDialect) ||
455+
op->hasTrait<OpTrait::ConstantLike>() ||
456+
op->hasTrait<OpTrait::Vectorizable>();
457+
if (!supported)
458+
return false;
498459

499-
return false;
460+
return TypeSwitch<Operation *, bool>(op)
461+
// As type legalization is done with vector.shape_cast, shape_cast
462+
// itself cannot be linearized (will create new shape_casts to linearize
463+
// ad infinitum).
464+
.Case<vector::ShapeCastOp>([&](auto) { return false; })
465+
// vector.extract_strided_slice, vector.extract, and vector.insert
466+
// operations are linearized to a rank-1 vector.shuffle by the current
467+
// patterns. vector.shuffle only supports fixed size vectors, so it is
468+
// impossible to use this approach to linearize these ops if they operate
469+
// on scalable vectors.
470+
.Case<vector::ExtractStridedSliceOp>(
471+
[&](vector::ExtractStridedSliceOp extractOp) {
472+
return !extractOp.getType().isScalable();
473+
})
474+
.Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
475+
return !insertOp.getType().isScalable();
476+
})
477+
.Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
478+
return !extractOp.getSourceVectorType().isScalable();
479+
})
480+
.Default([&](auto) { return true; });
500481
}
501482

502483
void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
@@ -530,7 +511,7 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
530511

531512
target.markUnknownOpDynamicallyLegal(
532513
[=](Operation *op) -> std::optional<bool> {
533-
if (isNotLinearizable(op))
514+
if (!isLinearizable(op))
534515
return true;
535516
// This will return true if, for all operand and result types `t`,
536517
// convertType(t) = t. This is true if there are no rank>=2 vectors.

‎mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Copy file name to clipboardExpand all lines: mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+1-6Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -837,8 +837,6 @@ struct TestVectorEmulateMaskedLoadStore final
837837
}
838838
};
839839

840-
namespace bit_width_constrained_linearization {
841-
842840
/// Get the set of operand/result types to check for sufficiently
843841
/// small inner-most dimension size.
844842
static SmallVector<std::pair<Type, unsigned>>
@@ -960,8 +958,6 @@ struct TestVectorBitWidthLinearize final
960958
}
961959
};
962960

963-
} // namespace bit_width_constrained_linearization
964-
965961
struct TestVectorLinearize final
966962
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
967963
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
@@ -1069,8 +1065,7 @@ void registerTestVectorLowerings() {
10691065

10701066
PassRegistration<TestVectorLinearize>();
10711067

1072-
PassRegistration<
1073-
bit_width_constrained_linearization::TestVectorBitWidthLinearize>();
1068+
PassRegistration<TestVectorBitWidthLinearize>();
10741069

10751070
PassRegistration<TestEliminateVectorMasks>();
10761071
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.