@@ -444,59 +444,40 @@ struct LinearizeVectorSplat final
444
444
445
445
} // namespace
446
446
447
- // / Some operations currently will not be linearized if they have scalable
448
- // / vector results, although support should be added in the future. This
449
- // / function returns true if `op` is such an operation.
450
- static bool isNotLinearizableBecauseScalable (Operation *op) {
451
-
452
- bool unsupported =
453
- isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
454
- op);
455
-
456
- // Case where linearization is possible even when there are scalable vector
457
- // results.
458
- if (!unsupported)
459
- return false ;
460
-
461
- // Check if any of the results is a scalable vector type, and if there are
462
- // return true (not linearizable).
463
- auto types = op->getResultTypes ();
464
- bool containsScalableResult =
465
- std::any_of (types.begin (), types.end (), [](Type type) {
466
- auto vecType = dyn_cast<VectorType>(type);
467
- return vecType && vecType.isScalable ();
468
- });
469
-
470
- return containsScalableResult;
471
- }
472
-
473
- // / This method defines a set of operations that are not linearizable, and hence
474
- // / they are considered legal for the conversion target. These ops are
475
- // / currently,
476
- // /
477
- // / 1) ones that are not in the vector dialect, are not ConstantLike, and are
478
- // / not Vectorizable, or
479
- // /
480
- // / 2) have scalable vector results, for which support has not yet been added.
481
- static bool isNotLinearizable (Operation *op) {
447
+ // / This method defines the set of operations that are linearizable, and hence
448
+ // / that are considered illegal for the conversion target.
449
+ static bool isLinearizable (Operation *op) {
482
450
451
+ // Operations such as builtin.module are not linearizable.
483
452
StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace ();
484
453
StringRef opDialect = op->getDialect ()->getNamespace ();
485
- bool unsupported = (opDialect != vectorDialect) &&
486
- !op->hasTrait <OpTrait::ConstantLike>() &&
487
- !op->hasTrait <OpTrait::Vectorizable>();
488
- if (unsupported)
489
- return true ;
490
-
491
- // vector.shape_cast cannot be linearized.
492
- if (isa<vector::ShapeCastOp>(op))
493
- return true ;
494
-
495
- // Some ops currently don't support scalable vectors.
496
- if (isNotLinearizableBecauseScalable (op))
497
- return true ;
454
+ bool supported = (opDialect == vectorDialect) ||
455
+ op->hasTrait <OpTrait::ConstantLike>() ||
456
+ op->hasTrait <OpTrait::Vectorizable>();
457
+ if (!supported)
458
+ return false ;
498
459
499
- return false ;
460
+ return TypeSwitch<Operation *, bool >(op)
461
+ // As type legalization is done with vector.shape_cast, shape_cast
462
+ // itself cannot be linearized (will create new shape_casts to linearize
463
+ // ad infinitum).
464
+ .Case <vector::ShapeCastOp>([&](auto ) { return false ; })
465
+ // vector.extract_strided_slice, vector.extract, and vector.insert
466
+ // operations are linearized to a rank-1 vector.shuffle by the current
467
+ // patterns. vector.shuffle only supports fixed size vectors, so it is
468
+ // impossible to use this approach to linearize these ops if they operate
469
+ // on scalable vectors.
470
+ .Case <vector::ExtractStridedSliceOp>(
471
+ [&](vector::ExtractStridedSliceOp extractOp) {
472
+ return !extractOp.getType ().isScalable ();
473
+ })
474
+ .Case <vector::InsertOp>([&](vector::InsertOp insertOp) {
475
+ return !insertOp.getType ().isScalable ();
476
+ })
477
+ .Case <vector::ExtractOp>([&](vector::ExtractOp extractOp) {
478
+ return !extractOp.getSourceVectorType ().isScalable ();
479
+ })
480
+ .Default ([&](auto ) { return true ; });
500
481
}
501
482
502
483
void mlir::vector::populateForVectorLinearize (TypeConverter &typeConverter,
@@ -530,7 +511,7 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
530
511
531
512
target.markUnknownOpDynamicallyLegal (
532
513
[=](Operation *op) -> std::optional<bool > {
533
- if (isNotLinearizable (op))
514
+ if (! isLinearizable (op))
534
515
return true ;
535
516
// This will return true if, for all operand and result types `t`,
536
517
// convertType(t) = t. This is true if there are no rank>=2 vectors.
0 commit comments