[netdiag] fix setting DIAG_GET response callback (#6192)
This commit sets the DIAG_GET response callback every time the
otThreadSendDiagnosticGet is called. This is because the callback may
be changed to other modules (e,g. OTBR_REST, OTBR_UBUS).
diff --git a/include/openthread/instance.h b/include/openthread/instance.h
index adb5502..186edd0 100644
--- a/include/openthread/instance.h
+++ b/include/openthread/instance.h
@@ -53,7 +53,7 @@
* @note This number versions both OpenThread platform and user APIs.
*
*/
-#define OPENTHREAD_API_VERSION (78)
+#define OPENTHREAD_API_VERSION (79)
/**
* @addtogroup api-instance
diff --git a/include/openthread/netdiag.h b/include/openthread/netdiag.h
index 376780f..2d291a0 100644
--- a/include/openthread/netdiag.h
+++ b/include/openthread/netdiag.h
@@ -297,34 +297,26 @@
void * aContext);
/**
- * This function registers a callback to provide received raw Network Diagnostic Get response payload.
+ * Send a Network Diagnostic Get request.
*
* @param[in] aInstance A pointer to an OpenThread instance.
+ * @param[in] aDestination A pointer to destination address.
+ * @param[in] aTlvTypes An array of Network Diagnostic TLV types.
+ * @param[in] aCount Number of types in aTlvTypes.
* @param[in] aCallback A pointer to a function that is called when Network Diagnostic Get response
* is received or NULL to disable the callback.
* @param[in] aCallbackContext A pointer to application-specific context.
*
- */
-void otThreadSetReceiveDiagnosticGetCallback(otInstance * aInstance,
- otReceiveDiagnosticGetCallback aCallback,
- void * aCallbackContext);
-
-/**
- * Send a Network Diagnostic Get request.
- *
- * @param[in] aInstance A pointer to an OpenThread instance.
- * @param[in] aDestination A pointer to destination address.
- * @param[in] aTlvTypes An array of Network Diagnostic TLV types.
- * @param[in] aCount Number of types in aTlvTypes.
- *
* @retval OT_ERROR_NONE Successfully queued the DIAG_GET.req.
* @retval OT_ERROR_NO_BUFS Insufficient message buffers available to send DIAG_GET.req.
*
*/
-otError otThreadSendDiagnosticGet(otInstance * aInstance,
- const otIp6Address *aDestination,
- const uint8_t aTlvTypes[],
- uint8_t aCount);
+otError otThreadSendDiagnosticGet(otInstance * aInstance,
+ const otIp6Address * aDestination,
+ const uint8_t aTlvTypes[],
+ uint8_t aCount,
+ otReceiveDiagnosticGetCallback aCallback,
+ void * aCallbackContext);
/**
* Send a Network Diagnostic Reset request.
diff --git a/src/cli/cli.cpp b/src/cli/cli.cpp
index f7dcbb9..16eecdc 100644
--- a/src/cli/cli.cpp
+++ b/src/cli/cli.cpp
@@ -147,9 +147,6 @@
, mSrpServer(*this)
#endif
{
-#if OPENTHREAD_FTD || OPENTHREAD_CONFIG_TMF_NETWORK_DIAG_MTD_ENABLE
- otThreadSetReceiveDiagnosticGetCallback(mInstance, &Interpreter::HandleDiagnosticGetResponse, this);
-#endif
#if OPENTHREAD_FTD
otThreadSetDiscoveryRequestCallback(mInstance, &Interpreter::HandleDiscoveryRequest, this);
#endif
@@ -4850,7 +4847,8 @@
if (strcmp(aArgs[0], "get") == 0)
{
- IgnoreError(otThreadSendDiagnosticGet(mInstance, &address, tlvTypes, count));
+ SuccessOrExit(error = otThreadSendDiagnosticGet(mInstance, &address, tlvTypes, count,
+ &Interpreter::HandleDiagnosticGetResponse, this));
ExitNow(error = OT_ERROR_PENDING);
}
else if (strcmp(aArgs[0], "reset") == 0)
diff --git a/src/core/api/netdiag_api.cpp b/src/core/api/netdiag_api.cpp
index 5b75e0e..3e8b836 100644
--- a/src/core/api/netdiag_api.cpp
+++ b/src/core/api/netdiag_api.cpp
@@ -49,24 +49,17 @@
*aIterator, *aNetworkDiagTlv);
}
-void otThreadSetReceiveDiagnosticGetCallback(otInstance * aInstance,
- otReceiveDiagnosticGetCallback aCallback,
- void * aCallbackContext)
-{
- Instance &instance = *static_cast<Instance *>(aInstance);
-
- instance.Get<NetworkDiagnostic::NetworkDiagnostic>().SetReceiveDiagnosticGetCallback(aCallback, aCallbackContext);
-}
-
-otError otThreadSendDiagnosticGet(otInstance * aInstance,
- const otIp6Address *aDestination,
- const uint8_t aTlvTypes[],
- uint8_t aCount)
+otError otThreadSendDiagnosticGet(otInstance * aInstance,
+ const otIp6Address * aDestination,
+ const uint8_t aTlvTypes[],
+ uint8_t aCount,
+ otReceiveDiagnosticGetCallback aCallback,
+ void * aCallbackContext)
{
Instance &instance = *static_cast<Instance *>(aInstance);
return instance.Get<NetworkDiagnostic::NetworkDiagnostic>().SendDiagnosticGet(
- *static_cast<const Ip6::Address *>(aDestination), aTlvTypes, aCount);
+ *static_cast<const Ip6::Address *>(aDestination), aTlvTypes, aCount, aCallback, aCallbackContext);
}
otError otThreadSendDiagnosticReset(otInstance * aInstance,
diff --git a/src/core/thread/network_diagnostic.cpp b/src/core/thread/network_diagnostic.cpp
index a26525b..0365549 100644
--- a/src/core/thread/network_diagnostic.cpp
+++ b/src/core/thread/network_diagnostic.cpp
@@ -69,16 +69,11 @@
Get<Tmf::TmfAgent>().AddResource(mDiagnosticReset);
}
-void NetworkDiagnostic::SetReceiveDiagnosticGetCallback(otReceiveDiagnosticGetCallback aCallback,
- void * aCallbackContext)
-{
- mReceiveDiagnosticGetCallback = aCallback;
- mReceiveDiagnosticGetCallbackContext = aCallbackContext;
-}
-
-otError NetworkDiagnostic::SendDiagnosticGet(const Ip6::Address &aDestination,
- const uint8_t aTlvTypes[],
- uint8_t aCount)
+otError NetworkDiagnostic::SendDiagnosticGet(const Ip6::Address & aDestination,
+ const uint8_t aTlvTypes[],
+ uint8_t aCount,
+ otReceiveDiagnosticGetCallback aCallback,
+ void * aCallbackContext)
{
otError error;
Coap::Message * message = nullptr;
@@ -122,6 +117,9 @@
SuccessOrExit(error = Get<Tmf::TmfAgent>().SendMessage(*message, messageInfo, handler, this));
+ mReceiveDiagnosticGetCallback = aCallback;
+ mReceiveDiagnosticGetCallbackContext = aCallbackContext;
+
otLogInfoNetDiag("Sent diagnostic get");
exit:
diff --git a/src/core/thread/network_diagnostic.hpp b/src/core/thread/network_diagnostic.hpp
index 510a9ea..3d7a9cf 100644
--- a/src/core/thread/network_diagnostic.hpp
+++ b/src/core/thread/network_diagnostic.hpp
@@ -82,25 +82,22 @@
explicit NetworkDiagnostic(Instance &aInstance);
/**
- * This method registers a callback to provide received raw DIAG_GET.rsp or an DIAG_GET.ans payload.
- *
- * @param[in] aCallback A pointer to a function that is called when an DIAG_GET.rsp or an DIAG_GET.ans
- * is received or nullptr to disable the callback.
- * @param[in] aCallbackContext A pointer to application-specific context.
- *
- */
- void SetReceiveDiagnosticGetCallback(otReceiveDiagnosticGetCallback aCallback, void *aCallbackContext);
-
- /**
* This method sends Diagnostic Get request. If the @p aDestination is of multicast type, the DIAG_GET.qry
* message is sent or the DIAG_GET.req otherwise.
*
- * @param[in] aDestination A reference to the destination address.
- * @param[in] aTlvTypes An array of Network Diagnostic TLV types.
- * @param[in] aCount Number of types in aTlvTypes
+ * @param[in] aDestination A reference to the destination address.
+ * @param[in] aTlvTypes An array of Network Diagnostic TLV types.
+ * @param[in] aCount Number of types in aTlvTypes.
+ * @param[in] aCallback A pointer to a function that is called when Network Diagnostic Get response
+ * is received or NULL to disable the callback.
+ * @param[in] aCallbackContext A pointer to application-specific context.
*
*/
- otError SendDiagnosticGet(const Ip6::Address &aDestination, const uint8_t aTlvTypes[], uint8_t aCount);
+ otError SendDiagnosticGet(const Ip6::Address & aDestination,
+ const uint8_t aTlvTypes[],
+ uint8_t aCount,
+ otReceiveDiagnosticGetCallback aCallback,
+ void * aCallbackContext);
/**
* This method sends Diagnostic Reset request.