@@ -3257,178 +3257,6 @@ class AIREnforceLoopCarriedMemrefDeallocPattern
32573257private:
32583258};
32593259
3260- // A pass which de-alias a memref with multiple channel accesses over time, into
3261- // multiple memrefs. Note that this implementation is temporary and not generic.
3262- // TODO: Rewrite as a graph partitioning problem.
3263- class AIRDeAliasMemref
3264- : public xilinx::air::impl::AIRDeAliasMemrefBase<AIRDeAliasMemref> {
3265-
3266- public:
3267- AIRDeAliasMemref () = default ;
3268- AIRDeAliasMemref (const AIRDeAliasMemref &pass) {}
3269-
3270- void getDependentDialects (::mlir::DialectRegistry ®istry) const override {
3271- registry.insert <scf::SCFDialect, air::airDialect>();
3272- }
3273-
3274- void runOnFunction (func::FuncOp f) {
3275-
3276- std::vector<memref::AllocOp> allocs;
3277- f.walk ([&](memref::AllocOp alloc) { allocs.push_back (alloc); });
3278-
3279- // Count air.channel references
3280- for (auto alloc : allocs) {
3281- Value memref = nullptr ;
3282- if (auto exec = alloc->getParentOfType <air::ExecuteOp>()) {
3283- memref = exec->getResult (1 );
3284- } else
3285- memref = alloc.getMemref ();
3286- std::vector<air::ChannelInterface> chan_puts_gets;
3287- for (auto user : memref.getUsers ()) {
3288- if (auto putget = dyn_cast<air::ChannelInterface>(user))
3289- if (putget.getMemref () == memref)
3290- chan_puts_gets.push_back (putget);
3291- }
3292-
3293- // Partition the subgraph
3294- std::vector<int > partition_cuts;
3295- if (!chan_puts_gets.empty ()) {
3296- for (unsigned i = 0 ; i < chan_puts_gets.size () - 1 ; i++) {
3297- if (isa<air::ChannelGetOp>(chan_puts_gets[i].getOperation ()) &&
3298- isa<air::ChannelPutOp>(chan_puts_gets[i + 1 ].getOperation ())) {
3299- partition_cuts.push_back (i + 1 );
3300- }
3301- }
3302- }
3303-
3304- // Allocate new memref per cut
3305- std::vector<Operation *> new_memallocs;
3306- for (unsigned i = 0 ; i < partition_cuts.size (); i++) {
3307- OpBuilder builder (alloc);
3308- Operation *new_op = nullptr ;
3309- if (auto exec = alloc->getParentOfType <air::ExecuteOp>()) {
3310- builder.setInsertionPoint (exec);
3311- new_op = builder.clone (*exec.getOperation ());
3312- } else
3313- new_op = builder.clone (*alloc.getOperation ());
3314- new_memallocs.push_back (new_op);
3315-
3316- // Create deallocs for the new memref
3317- Value new_memref = isa<air::ExecuteOp>(new_op) ? new_op->getResult (1 )
3318- : new_op->getResult (0 );
3319- for (auto user : memref.getUsers ()) {
3320- if (isa<memref::DeallocOp>(user)) {
3321- if (isa<air::ExecuteOp>(new_op)) {
3322- builder.setInsertionPoint (
3323- user->getParentOfType <air::ExecuteOp>());
3324- // Async. dealloc
3325- auto async_exec = builder.create <xilinx::air::ExecuteOp>(
3326- user->getLoc (), air::AsyncTokenType::get (alloc->getContext ()),
3327- SmallVector<Value>{});
3328- Block *async_exec_bb =
3329- builder.createBlock (&async_exec.getRegion ());
3330- builder.setInsertionPointToStart (async_exec_bb);
3331- builder.create <memref::DeallocOp>(user->getLoc (), new_memref);
3332- builder.create <air::ExecuteTerminatorOp>(user->getLoc ());
3333- } else {
3334- builder.setInsertionPoint (user);
3335- // Sync. dealloc
3336- builder.create <memref::DeallocOp>(user->getLoc (), new_memref);
3337- }
3338- }
3339- }
3340- }
3341-
3342- // Update references
3343- partition_cuts.insert (partition_cuts.end (), chan_puts_gets.size ());
3344- for (unsigned i = 0 ; i < partition_cuts.size () - 1 ; i++) {
3345- for (int j = partition_cuts[i]; j < partition_cuts[i + 1 ]; j++) {
3346- if (auto old_put = dyn_cast<air::ChannelPutOp>(
3347- chan_puts_gets[j].getOperation ())) {
3348- Value new_memref = isa<air::ExecuteOp>(new_memallocs[i])
3349- ? new_memallocs[i]->getResult (1 )
3350- : new_memallocs[i]->getResult (0 );
3351- OpBuilder builder (old_put);
3352- replaceChannelPutOp (builder, old_put, new_memref);
3353- } else if (auto old_get = dyn_cast<air::ChannelGetOp>(
3354- chan_puts_gets[j].getOperation ())) {
3355- Value new_memref = isa<air::ExecuteOp>(new_memallocs[i])
3356- ? new_memallocs[i]->getResult (1 )
3357- : new_memallocs[i]->getResult (0 );
3358- OpBuilder builder (old_get);
3359- replaceChannelGetOp (builder, old_get, new_memref);
3360- }
3361- }
3362- }
3363- }
3364- }
3365-
3366- void runOnOperation () override {
3367- auto module = getOperation ();
3368-
3369- SmallVector<func::FuncOp, 4 > funcOps;
3370- module .walk ([&](func::FuncOp op) { funcOps.push_back (op); });
3371- for (auto f : funcOps) {
3372- runOnFunction (f);
3373- }
3374- }
3375-
3376- private:
3377- Operation *replaceChannelPutOp (OpBuilder builder, air::ChannelPutOp old,
3378- Value new_memref) {
3379- builder.setInsertionPoint (old);
3380- SmallVector<Type, 1 > tys;
3381- if (old.getAsyncToken ()) {
3382- tys.push_back (air::AsyncTokenType::get (old->getContext ()));
3383- }
3384- SmallVector<Value, 4 > deps = old.getAsyncDependencies ();
3385- auto new_op = builder.create <air::ChannelPutOp>(
3386- old->getLoc (), tys, deps, old.getChanName (), old.getIndices (),
3387- new_memref, old.getSrcOffsets (), old.getSrcSizes (),
3388- old.getSrcStrides ());
3389- if (old.getAsyncToken ()) {
3390- old.getAsyncToken ().replaceAllUsesWith (new_op.getAsyncToken ());
3391- // Add dependence to the new memref
3392- new_op.addAsyncDependency (
3393- dyn_cast<air::ExecuteOp>(new_memref.getDefiningOp ()).getAsyncToken ());
3394- }
3395- if (old.getId () != -1 ) {
3396- new_op->setAttr (" id" , mlir::IntegerAttr::get (
3397- mlir::IntegerType::get (old->getContext (), 32 ),
3398- old.getId ()));
3399- }
3400- old->erase ();
3401- return new_op.getOperation ();
3402- }
3403- Operation *replaceChannelGetOp (OpBuilder builder, air::ChannelGetOp old,
3404- Value new_memref) {
3405- builder.setInsertionPoint (old);
3406- SmallVector<Type, 1 > tys;
3407- if (old.getAsyncToken ()) {
3408- tys.push_back (air::AsyncTokenType::get (old->getContext ()));
3409- }
3410- SmallVector<Value, 4 > deps = old.getAsyncDependencies ();
3411- auto new_op = builder.create <air::ChannelGetOp>(
3412- old->getLoc (), tys, deps, old.getChanName (), old.getIndices (),
3413- new_memref, old.getDstOffsets (), old.getDstSizes (),
3414- old.getDstStrides ());
3415- new_op->setAttrs (old->getDiscardableAttrDictionary ());
3416- if (old.getAsyncToken ()) {
3417- old.getAsyncToken ().replaceAllUsesWith (new_op.getAsyncToken ());
3418- // Add dependence to the new memref
3419- new_op.addAsyncDependency (
3420- dyn_cast<air::ExecuteOp>(new_memref.getDefiningOp ()).getAsyncToken ());
3421- }
3422- if (old.getId () != -1 ) {
3423- new_op->setAttr (" id" , mlir::IntegerAttr::get (
3424- mlir::IntegerType::get (old->getContext (), 32 ),
3425- old.getId ()));
3426- }
3427- old->erase ();
3428- return new_op.getOperation ();
3429- }
3430- };
3431-
34323260// A pass which transform multiple channel ops into one, where the data movement
34333261// is time-multiplexed.
34343262class AIRFuseChannels
@@ -6175,10 +6003,6 @@ std::unique_ptr<Pass> createAIREnforceLoopCarriedMemrefDeallocPattern() {
61756003 return std::make_unique<AIREnforceLoopCarriedMemrefDeallocPattern>();
61766004}
61776005
6178- std::unique_ptr<Pass> createAIRDeAliasMemref () {
6179- return std::make_unique<AIRDeAliasMemref>();
6180- }
6181-
61826006std::unique_ptr<Pass> createAIRFuseChannels () {
61836007 return std::make_unique<AIRFuseChannels>();
61846008}
0 commit comments