mlir: enable mincut support for scf.for #2359
Conversation
if (updatedGradients.empty() && caches.empty())
  return success();

if (forOp->hasAttr("enzyme.enable_mincut")) {
I think we should make this the default [and optionally allow disabling it].
We'll also want to add this to affine.for/{scf,affine}.parallel.
Makes sense, I will follow up with those (I don't think we have the reverse rules for these ops yet).
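To illustrate what a reverse rule for a counted loop has to do (a minimal Python sketch, not Enzyme's implementation; `f`, `df`, and `loop_reverse` are hypothetical names): the forward sweep must cache the value entering each iteration so the reverse sweep can propagate gradients in the opposite order.

```python
def f(x):
    # hypothetical loop body: y = x * x
    return x * x

def df(x, grad_out):
    # derivative of the body w.r.t. its input, given the incoming cotangent
    return 2.0 * x * grad_out

def loop_reverse(x, n, grad_out):
    cache = []
    for _ in range(n):                # forward sweep: record each iteration's input
        cache.append(x)
        x = f(x)
    for i in reversed(range(n)):      # reverse sweep: consume the cache backwards
        grad_out = df(cache[i], grad_out)
    return x, grad_out
```

With `n = 3` the loop computes `x**8`, so the returned gradient matches `8 * x**7`. The min-cut heuristic this PR enables decides *which* intermediate values are worth caching versus recomputing.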
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %[[v0:.+]] = tensor.empty() : tensor<10xf32>
// CHECK-NEXT: %[[for:.+]]:2 = scf.for %arg2 = %c0 to %c10 step %c1 iter_args(%arg3 = %arg0, %arg4 = %[[v0]]) -> (f32, tensor<10xf32>) {
// CHECK-NEXT: %[[cache:.+]] = tensor.insert %arg3 into %arg4[%arg2] : tensor<10xf32>
We should likely prefer using a memref alloca of the correct size: store in the forward sweep, load in the reverse sweep, and free after the for.
It will create a cache<memref> for the parent op, so the alloca has to be hoisted all the way to the top (and would be invalid for the upcoming autodiff deferred). IIRC, that's why I went with an immutable data structure at the time.
We could write a bufferization pass that converts this to memref once all push/pop pairs have been removed?
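The two caching strategies being weighed above can be sketched in Python (assumed semantics only, not Enzyme code; the function names are illustrative): the tensor-based cache behaves like an immutable value rebuilt on each insert, while the proposed memref-style cache is a mutable buffer allocated at the right size up front, with store-in-forward / load-in-reverse.

```python
def tensor_style(values):
    # immutable cache: each "tensor.insert" conceptually yields a new value,
    # which is what makes it safe to thread through iter_args
    cache = ()
    for v in values:
        cache = cache + (v,)
    return list(cache)

def memref_style(values):
    n = len(values)
    buf = [0.0] * n                   # "alloca" of the correct size
    for i, v in enumerate(values):    # store in the forward sweep
        buf[i] = v
    out = []
    for i in reversed(range(n)):      # load in the reverse sweep
        out.append(buf[i])
    return out                        # buffer would be freed after the loop
```

The mutable version avoids rebuilding the cache value each iteration, but as noted above it pins an allocation in the parent scope, which is what the immutable form sidesteps.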
@Pangoraw I'm going to merge for now; can you address this in a post-commit follow-up?