Skip to content

Commit

Permalink
[Python] fixed subscription crash (#32257)
Browse files Browse the repository at this point in the history
* [Python] After SubscriptionTransaction has an error, calling Shutdown() will crash

* Add a comment

* update comment
  • Loading branch information
tianfeng-yang authored Oct 14, 2024
1 parent 4df081c commit dcb4444
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 55 deletions.
126 changes: 80 additions & 46 deletions src/controller/python/chip/clusters/Attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ def __post_init__(self):
'''Only one of either ClusterType and AttributeType OR Path may be provided.'''

if (self.ClusterType is not None and self.AttributeType is not None) and self.Path is not None:
raise ValueError("Only one of either ClusterType and AttributeType OR Path may be provided.")
raise ValueError(
"Only one of either ClusterType and AttributeType OR Path may be provided.")
if (self.ClusterType is None or self.AttributeType is None) and self.Path is None:
raise ValueError("Either ClusterType and AttributeType OR Path must be provided.")
raise ValueError(
"Either ClusterType and AttributeType OR Path must be provided.")

# if ClusterType and AttributeType were provided we can continue onwards to deriving the label.
# Otherwise, we'll need to walk the attribute index to find the right type information.
Expand Down Expand Up @@ -373,7 +375,8 @@ def handle_cluster_view(endpointId, clusterId, clusterType):
try:
decodedData = clusterType.FromDict(
data=clusterType.descriptor.TagDictToLabelDict([], self.attributeTLVCache[endpointId][clusterId]))
decodedData.SetDataVersion(self.versionList.get(endpointId, {}).get(clusterId))
decodedData.SetDataVersion(
self.versionList.get(endpointId, {}).get(clusterId))
return decodedData
except Exception as ex:
return ValueDecodeFailure(self.attributeTLVCache[endpointId][clusterId], ex)
Expand Down Expand Up @@ -404,12 +407,14 @@ def handle_attribute_view(endpointId, clusterId, attributeId, attributeType):
clusterType = _ClusterIndex[clusterId]

if self.returnClusterObject:
endpointCache[clusterType] = handle_cluster_view(endpointId, clusterId, clusterType)
endpointCache[clusterType] = handle_cluster_view(
endpointId, clusterId, clusterType)
else:
if clusterType not in endpointCache:
endpointCache[clusterType] = {}
clusterCache = endpointCache[clusterType]
clusterCache[DataVersion] = self.versionList.get(endpointId, {}).get(clusterId)
clusterCache[DataVersion] = self.versionList.get(
endpointId, {}).get(clusterId)

if (clusterId, attributeId) not in _AttributeIndex:
#
Expand All @@ -419,7 +424,8 @@ def handle_attribute_view(endpointId, clusterId, attributeId, attributeType):
continue

attributeType = _AttributeIndex[(clusterId, attributeId)][0]
clusterCache[attributeType] = handle_attribute_view(endpointId, clusterId, attributeId, attributeType)
clusterCache[attributeType] = handle_attribute_view(
endpointId, clusterId, attributeId, attributeType)
self._attributeCacheUpdateNeeded.clear()
return self._attributeCache

Expand All @@ -428,14 +434,18 @@ class SubscriptionTransaction:
def __init__(self, transaction: AsyncReadTransaction, subscriptionId, devCtrl):
self._onResubscriptionAttemptedCb: Callable[[SubscriptionTransaction,
int, int], None] = DefaultResubscriptionAttemptedCallback
self._onAttributeChangeCb: Callable[[TypedAttributePath, SubscriptionTransaction], None] = DefaultAttributeChangeCallback
self._onEventChangeCb: Callable[[EventReadResult, SubscriptionTransaction], None] = DefaultEventChangeCallback
self._onErrorCb: Callable[[int, SubscriptionTransaction], None] = DefaultErrorCallback
self._onAttributeChangeCb: Callable[[
TypedAttributePath, SubscriptionTransaction], None] = DefaultAttributeChangeCallback
self._onEventChangeCb: Callable[[
EventReadResult, SubscriptionTransaction], None] = DefaultEventChangeCallback
self._onErrorCb: Callable[[
int, SubscriptionTransaction], None] = DefaultErrorCallback
self._readTransaction = transaction
self._subscriptionId = subscriptionId
self._devCtrl = devCtrl
self._isDone = False
self._onResubscriptionSucceededCb: Optional[Callable[[SubscriptionTransaction], None]] = None
self._onResubscriptionSucceededCb: Optional[Callable[[
SubscriptionTransaction], None]] = None
self._onResubscriptionSucceededCb_isAsync = False
self._onResubscriptionAttemptedCb_isAsync = False

Expand All @@ -460,7 +470,8 @@ def GetEvents(self):
def OverrideLivenessTimeoutMs(self, timeoutMs: int):
handle = chip.native.GetLibraryHandle()
builtins.chipStack.Call(
lambda: handle.pychip_ReadClient_OverrideLivenessTimeout(self._readTransaction._pReadClient, timeoutMs)
lambda: handle.pychip_ReadClient_OverrideLivenessTimeout(
self._readTransaction._pReadClient, timeoutMs)
)

async def TriggerResubscribeIfScheduled(self, reason: str):
Expand Down Expand Up @@ -501,7 +512,8 @@ def GetSubscriptionTimeoutMs(self) -> int:
timeoutMs = ctypes.c_uint32(0)
handle = chip.native.GetLibraryHandle()
builtins.chipStack.Call(
lambda: handle.pychip_ReadClient_GetSubscriptionTimeoutMs(self._readTransaction._pReadClient, ctypes.pointer(timeoutMs))
lambda: handle.pychip_ReadClient_GetSubscriptionTimeoutMs(
self._readTransaction._pReadClient, ctypes.pointer(timeoutMs))
)
return timeoutMs.value

Expand Down Expand Up @@ -567,13 +579,14 @@ def subscriptionId(self) -> int:

def Shutdown(self):
if (self._isDone):
LOGGER.warning("Subscription 0x%08x was already terminated previously!", self.subscriptionId)
LOGGER.warning(
"Subscription 0x%08x was already terminated previously!", self.subscriptionId)
return

handle = chip.native.GetLibraryHandle()
builtins.chipStack.Call(
lambda: handle.pychip_ReadClient_Abort(
self._readTransaction._pReadClient, self._readTransaction._pReadCallback))
lambda: handle.pychip_ReadClient_ShutdownSubscription(
self._readTransaction._pReadClient))
self._isDone = True

def __del__(self):
Expand All @@ -585,7 +598,8 @@ def __repr__(self):

def DefaultResubscriptionAttemptedCallback(transaction: SubscriptionTransaction,
terminationError, nextResubscribeIntervalMsec):
print(f"Previous subscription failed with Error: {terminationError} - re-subscribing in {nextResubscribeIntervalMsec}ms...")
print(
f"Previous subscription failed with Error: {terminationError} - re-subscribing in {nextResubscribeIntervalMsec}ms...")


def DefaultAttributeChangeCallback(path: TypedAttributePath, transaction: SubscriptionTransaction):
Expand Down Expand Up @@ -648,12 +662,10 @@ def __init__(self, future: Future, eventLoop, devCtrl, returnClusterObject: bool
self._cache = AttributeCache(returnClusterObject=returnClusterObject)
self._changedPathSet: Set[AttributePath] = set()
self._pReadClient = None
self._pReadCallback = None
self._resultError: Optional[PyChipError] = None

def SetClientObjPointers(self, pReadClient, pReadCallback):
def SetClientObjPointers(self, pReadClient):
self._pReadClient = pReadClient
self._pReadCallback = pReadCallback

def GetAllEventValues(self):
return self._events
Expand Down Expand Up @@ -729,7 +741,8 @@ def handleEventData(self, header: EventHeader, path: EventPath, data: bytes, sta

def handleError(self, chipError: PyChipError):
if self._subscription_handler:
self._subscription_handler.OnErrorCb(chipError.code, self._subscription_handler)
self._subscription_handler.OnErrorCb(
chipError.code, self._subscription_handler)
self._resultError = chipError

def _handleSubscriptionEstablished(self, subscriptionId):
Expand All @@ -744,7 +757,8 @@ def _handleSubscriptionEstablished(self, subscriptionId):
self._event_loop.create_task(
self._subscription_handler._onResubscriptionSucceededCb(self._subscription_handler))
else:
self._subscription_handler._onResubscriptionSucceededCb(self._subscription_handler)
self._subscription_handler._onResubscriptionSucceededCb(
self._subscription_handler)

def handleSubscriptionEstablished(self, subscriptionId):
self._event_loop.call_soon_threadsafe(
Expand Down Expand Up @@ -820,7 +834,8 @@ def __init__(self, future: Future, eventLoop):
def handleResponse(self, path: AttributePath, status: int):
try:
imStatus = chip.interaction_model.Status(status)
self._resultData.append(AttributeWriteResult(Path=path, Status=imStatus))
self._resultData.append(
AttributeWriteResult(Path=path, Status=imStatus))
except ValueError as ex:
LOGGER.exception(ex)

Expand All @@ -835,8 +850,10 @@ def _handleDone(self):
#
if self._resultError is not None:
if self._resultError.sdk_part is ErrorSDKPart.IM_GLOBAL_STATUS:
im_status = chip.interaction_model.Status(self._resultError.sdk_code)
self._future.set_exception(chip.interaction_model.InteractionModelError(im_status))
im_status = chip.interaction_model.Status(
self._resultError.sdk_code)
self._future.set_exception(
chip.interaction_model.InteractionModelError(im_status))
else:
self._future.set_exception(self._resultError.to_exception())
else:
Expand All @@ -856,7 +873,8 @@ def handleDone(self):
_OnReadAttributeDataCallbackFunct = CFUNCTYPE(
None, py_object, c_uint32, c_uint16, c_uint32, c_uint32, c_uint8, c_void_p, c_size_t)
_OnSubscriptionEstablishedCallbackFunct = CFUNCTYPE(None, py_object, c_uint32)
_OnResubscriptionAttemptedCallbackFunct = CFUNCTYPE(None, py_object, PyChipError, c_uint32)
_OnResubscriptionAttemptedCallbackFunct = CFUNCTYPE(
None, py_object, PyChipError, c_uint32)
_OnReadEventDataCallbackFunct = CFUNCTYPE(
None, py_object, c_uint16, c_uint32, c_uint32, c_uint64, c_uint8, c_uint64, c_uint8, c_void_p, c_size_t, c_uint8)
_OnReadErrorCallbackFunct = CFUNCTYPE(
Expand Down Expand Up @@ -897,7 +915,8 @@ def _OnSubscriptionEstablishedCallback(closure, subscriptionId):

@_OnResubscriptionAttemptedCallbackFunct
def _OnResubscriptionAttemptedCallback(closure, terminationCause: PyChipError, nextResubscribeIntervalMsec: int):
closure.handleResubscriptionAttempted(terminationCause, nextResubscribeIntervalMsec)
closure.handleResubscriptionAttempted(
terminationCause, nextResubscribeIntervalMsec)


@_OnReadErrorCallbackFunct
Expand Down Expand Up @@ -954,25 +973,34 @@ def WriteAttributes(future: Future, eventLoop, device,
pyWriteAttributes = pyWriteAttributesArrayType()
for idx, attr in enumerate(attributes):
if attr.Attribute.must_use_timed_write and timedRequestTimeoutMs is None or timedRequestTimeoutMs == 0:
raise chip.interaction_model.InteractionModelError(chip.interaction_model.Status.NeedsTimedInteraction)
raise chip.interaction_model.InteractionModelError(
chip.interaction_model.Status.NeedsTimedInteraction)

tlv = attr.Attribute.ToTLV(None, attr.Data)

pyWriteAttributes[idx].attributePath.endpointId = c_uint16(attr.EndpointId)
pyWriteAttributes[idx].attributePath.clusterId = c_uint32(attr.Attribute.cluster_id)
pyWriteAttributes[idx].attributePath.attributeId = c_uint32(attr.Attribute.attribute_id)
pyWriteAttributes[idx].attributePath.dataVersion = c_uint32(attr.DataVersion)
pyWriteAttributes[idx].attributePath.hasDataVersion = c_uint8(attr.HasDataVersion)
pyWriteAttributes[idx].tlvData = cast(ctypes.c_char_p(bytes(tlv)), c_void_p)
pyWriteAttributes[idx].attributePath.endpointId = c_uint16(
attr.EndpointId)
pyWriteAttributes[idx].attributePath.clusterId = c_uint32(
attr.Attribute.cluster_id)
pyWriteAttributes[idx].attributePath.attributeId = c_uint32(
attr.Attribute.attribute_id)
pyWriteAttributes[idx].attributePath.dataVersion = c_uint32(
attr.DataVersion)
pyWriteAttributes[idx].attributePath.hasDataVersion = c_uint8(
attr.HasDataVersion)
pyWriteAttributes[idx].tlvData = cast(
ctypes.c_char_p(bytes(tlv)), c_void_p)
pyWriteAttributes[idx].tlvLength = c_size_t(len(tlv))

transaction = AsyncWriteTransaction(future, eventLoop)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(transaction))
res = builtins.chipStack.Call(
lambda: handle.pychip_WriteClient_WriteAttributes(
ctypes.py_object(transaction), device,
ctypes.c_size_t(0 if timedRequestTimeoutMs is None else timedRequestTimeoutMs),
ctypes.c_size_t(0 if interactionTimeoutMs is None else interactionTimeoutMs),
ctypes.c_size_t(
0 if timedRequestTimeoutMs is None else timedRequestTimeoutMs),
ctypes.c_size_t(
0 if interactionTimeoutMs is None else interactionTimeoutMs),
ctypes.c_size_t(0 if busyWaitMs is None else busyWaitMs),
pyWriteAttributes, ctypes.c_size_t(numberOfAttributes))
)
Expand All @@ -991,12 +1019,18 @@ def WriteGroupAttributes(groupId: int, devCtrl: c_void_p, attributes: List[Attri

tlv = attr.Attribute.ToTLV(None, attr.Data)

pyWriteAttributes[idx].attributePath.endpointId = c_uint16(attr.EndpointId)
pyWriteAttributes[idx].attributePath.clusterId = c_uint32(attr.Attribute.cluster_id)
pyWriteAttributes[idx].attributePath.attributeId = c_uint32(attr.Attribute.attribute_id)
pyWriteAttributes[idx].attributePath.dataVersion = c_uint32(attr.DataVersion)
pyWriteAttributes[idx].attributePath.hasDataVersion = c_uint8(attr.HasDataVersion)
pyWriteAttributes[idx].tlvData = cast(ctypes.c_char_p(bytes(tlv)), c_void_p)
pyWriteAttributes[idx].attributePath.endpointId = c_uint16(
attr.EndpointId)
pyWriteAttributes[idx].attributePath.clusterId = c_uint32(
attr.Attribute.cluster_id)
pyWriteAttributes[idx].attributePath.attributeId = c_uint32(
attr.Attribute.attribute_id)
pyWriteAttributes[idx].attributePath.dataVersion = c_uint32(
attr.DataVersion)
pyWriteAttributes[idx].attributePath.hasDataVersion = c_uint8(
attr.HasDataVersion)
pyWriteAttributes[idx].tlvData = cast(
ctypes.c_char_p(bytes(tlv)), c_void_p)
pyWriteAttributes[idx].tlvLength = c_size_t(len(tlv))

return builtins.chipStack.Call(
Expand Down Expand Up @@ -1071,7 +1105,8 @@ def Read(transaction: AsyncReadTransaction, device,
"DataVersionFilter must provide DataVersion.")
filter = chip.interaction_model.DataVersionFilterIBstruct.build(
filter)
dataVersionFiltersForCffi[idx] = cast(ctypes.c_char_p(filter), c_void_p)
dataVersionFiltersForCffi[idx] = cast(
ctypes.c_char_p(filter), c_void_p)

eventPathsForCffi = None
if events is not None:
Expand All @@ -1095,7 +1130,6 @@ def Read(transaction: AsyncReadTransaction, device,
eventPathsForCffi[idx] = cast(ctypes.c_char_p(path), c_void_p)

readClientObj = ctypes.POINTER(c_void_p)()
readCallbackObj = ctypes.POINTER(c_void_p)()

ctypes.pythonapi.Py_IncRef(ctypes.py_object(transaction))
params = _ReadParams.parse(b'\x00' * _ReadParams.sizeof())
Expand All @@ -1109,13 +1143,13 @@ def Read(transaction: AsyncReadTransaction, device,
params = _ReadParams.build(params)
eventNumberFilterPtr = ctypes.POINTER(ctypes.c_ulonglong)()
if eventNumberFilter is not None:
eventNumberFilterPtr = ctypes.POINTER(ctypes.c_ulonglong)(ctypes.c_ulonglong(eventNumberFilter))
eventNumberFilterPtr = ctypes.POINTER(ctypes.c_ulonglong)(
ctypes.c_ulonglong(eventNumberFilter))

res = builtins.chipStack.Call(
lambda: handle.pychip_ReadClient_Read(
ctypes.py_object(transaction),
ctypes.byref(readClientObj),
ctypes.byref(readCallbackObj),
device,
ctypes.c_char_p(params),
attributePathsForCffi,
Expand All @@ -1127,7 +1161,7 @@ def Read(transaction: AsyncReadTransaction, device,
ctypes.c_size_t(0 if events is None else len(events)),
eventNumberFilterPtr))

transaction.SetClientObjPointers(readClientObj, readCallbackObj)
transaction.SetClientObjPointers(readClientObj)

if not res.is_success:
ctypes.pythonapi.Py_DecRef(ctypes.py_object(transaction))
Expand Down
25 changes: 16 additions & 9 deletions src/controller/python/chip/clusters/attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,12 +453,20 @@ PyChipError pychip_WriteClient_WriteGroupAttributes(size_t groupIdSizeT, chip::C
return ToPyChipError(err);
}

void pychip_ReadClient_Abort(ReadClient * apReadClient, ReadClientCallback * apCallback)
void pychip_ReadClient_ShutdownSubscription(ReadClient * apReadClient)
{
VerifyOrDie(apReadClient != nullptr);
VerifyOrDie(apCallback != nullptr);
// If apReadClient is nullptr, its life cycle has already ended (e.g. because an error happened), so there is nothing to do.
VerifyOrReturn(apReadClient != nullptr);
// If it is not SubscriptionType, this function should not be executed.
VerifyOrDie(apReadClient->IsSubscriptionType());

delete apCallback;
Optional<SubscriptionId> subscriptionId = apReadClient->GetSubscriptionId();
VerifyOrDie(subscriptionId.HasValue());

FabricIndex fabricIndex = apReadClient->GetFabricIndex();
NodeId nodeId = apReadClient->GetPeerNodeId();

InteractionModelEngine::GetInstance()->ShutdownSubscription(ScopedNodeId(nodeId, fabricIndex), subscriptionId.Value());
}

void pychip_ReadClient_OverrideLivenessTimeout(ReadClient * pReadClient, uint32_t livenessTimeoutMs)
Expand Down Expand Up @@ -497,10 +505,10 @@ void pychip_ReadClient_GetSubscriptionTimeoutMs(ReadClient * pReadClient, uint32
}
}

PyChipError pychip_ReadClient_Read(void * appContext, ReadClient ** pReadClient, ReadClientCallback ** pCallback,
DeviceProxy * device, uint8_t * readParamsBuf, void ** attributePathsFromPython,
size_t numAttributePaths, void ** dataversionFiltersFromPython, size_t numDataversionFilters,
void ** eventPathsFromPython, size_t numEventPaths, uint64_t * eventNumberFilter)
PyChipError pychip_ReadClient_Read(void * appContext, ReadClient ** pReadClient, DeviceProxy * device, uint8_t * readParamsBuf,
void ** attributePathsFromPython, size_t numAttributePaths, void ** dataversionFiltersFromPython,
size_t numDataversionFilters, void ** eventPathsFromPython, size_t numEventPaths,
uint64_t * eventNumberFilter)
{
CHIP_ERROR err = CHIP_NO_ERROR;
PyReadAttributeParams pyParams = {};
Expand Down Expand Up @@ -612,7 +620,6 @@ PyChipError pychip_ReadClient_Read(void * appContext, ReadClient ** pReadClient,
}

*pReadClient = readClient.get();
*pCallback = callback.get();

callback->AdoptReadClient(std::move(readClient));

Expand Down

0 comments on commit dcb4444

Please sign in to comment.