Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

[mlir][py] invalidate nested operations when parent is deleted #93339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions 35 mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,17 @@ void PyMlirContext::clearOperationsInside(MlirOperation op) {
clearOperationsInside(opRef->getOperation());
}

void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
void *userData) {
PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
contextRef->clearOperation(op);
return MlirWalkResult::MlirWalkResultAdvance;
};
mlirOperationWalk(op.getOperation(), invalidatingCallback,
&op.getOperation().getContext(), MlirWalkPreOrder);
}

size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }

pybind11::object PyMlirContext::contextEnter() {
Expand Down Expand Up @@ -1112,12 +1123,16 @@ PyOperation::~PyOperation() {
// If the operation has already been invalidated there is nothing to do.
if (!valid)
return;
auto &liveOperations = getContext()->liveOperations;
assert(liveOperations.count(operation.ptr) == 1 &&
"destroying operation not in live map");
liveOperations.erase(operation.ptr);
if (!isAttached()) {
mlirOperationDestroy(operation);

// Otherwise, invalidate the operation and remove it from live map when it is
// attached.
if (isAttached()) {
getContext()->clearOperation(*this);
} else {
// And destroy it when it is detached, i.e. owned by Python, in which case
// all nested operations must be invalidated at removed from the live map as
// well.
erase();
}
}

Expand Down Expand Up @@ -1527,14 +1542,8 @@ py::object PyOperation::createOpView() {

void PyOperation::erase() {
checkValid();
// TODO: Fix memory hazards when erasing a tree of operations for which a deep
// Python reference to a child operation is live. All children should also
// have their `valid` bit set to false.
auto &liveOperations = getContext()->liveOperations;
if (liveOperations.count(operation.ptr))
liveOperations.erase(operation.ptr);
getContext()->clearOperationAndInside(*this);
mlirOperationDestroy(operation);
valid = false;
}

//------------------------------------------------------------------------------
Expand Down
7 changes: 7 additions & 0 deletions 7 mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,19 @@ class PyMlirContext {
/// This is useful for when some non-bindings code destroys the operation and
/// the bindings need to made aware. For example, in the case when pass
/// manager is run.
///
/// Note that this does *NOT* clear the nested operations.
void clearOperation(MlirOperation op);

/// Clears all operations nested inside the given op using
/// `clearOperation(MlirOperation)`.
void clearOperationsInside(PyOperationBase &op);
void clearOperationsInside(MlirOperation op);

/// Clears the operaiton _and_ all operations inside using
/// `clearOperation(MlirOperation)`.
void clearOperationAndInside(PyOperationBase &op);

/// Gets the count of live modules associated with this context.
/// Used for testing.
size_t getLiveModuleCount();
Expand All @@ -246,6 +252,7 @@ class PyMlirContext {

private:
PyMlirContext(MlirContext context);

// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
// preserving the relationship that an MlirContext maps to a single
// PyMlirContext wrapper. This could be replaced in the future with an
Expand Down
46 changes: 46 additions & 0 deletions 46 mlir/test/python/live_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# RUN: %PYTHON %s
# It is sufficient that this doesn't assert.

from mlir.ir import *


def createDetachedModule():
module = Module.create()
with InsertionPoint(module.body):
# TODO: Python bindings are currently unaware that modules are also
# operations, so having a module erased won't trigger the cascading
# removal of live operations (#93337). Use a non-module operation
# instead.
nested = Operation.create("test.some_operation", regions=1)

# When the operation is detached from parent, it is considered to be
# owned by Python. It will therefore be erased when the Python object
# is destroyed.
nested.detach_from_parent()

# However, we create and maintain references to operations within
# `nested`. These references keep the corresponding operations in the
# "live" list even if they have been erased in C++, making them
# "zombie". If the C++ allocator reuses one of the address previously
# used for a now-"zombie" operation, this used to result in an
# assertion "cannot create detached operation that already exists" from
# the bindings code. Erasing the detached operation should result in
# removing all nested operations from the live list.
#
# Note that the assertion is not guaranteed since it depends on the
# behavior of the allocator on the C++ side, so this test mail fail
# intermittently.
with InsertionPoint(nested.regions[0].blocks.append()):
a = [Operation.create("test.some_other_operation") for i in range(100)]
return a


def createManyDetachedModules():
with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
for j in range(100):
a = createDetachedModule()


if __name__ == "__main__":
createManyDetachedModules()
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.