[netdiag] fix setting DIAG_GET response callback (#6192)

This commit sets the DIAG_GET response callback every time the
otThreadSendDiagnosticGet is called. This is because the callback may
be changed to other modules (e,g. OTBR_REST, OTBR_UBUS).
diff --git a/include/openthread/instance.h b/include/openthread/instance.h
index adb5502..186edd0 100644
--- a/include/openthread/instance.h
+++ b/include/openthread/instance.h
@@ -53,7 +53,7 @@
  * @note This number versions both OpenThread platform and user APIs.
  *
  */
-#define OPENTHREAD_API_VERSION (78)
+#define OPENTHREAD_API_VERSION (79)
 
 /**
  * @addtogroup api-instance
diff --git a/include/openthread/netdiag.h b/include/openthread/netdiag.h
index 376780f..2d291a0 100644
--- a/include/openthread/netdiag.h
+++ b/include/openthread/netdiag.h
@@ -297,34 +297,26 @@
                                                void *               aContext);
 
 /**
- * This function registers a callback to provide received raw Network Diagnostic Get response payload.
+ * Send a Network Diagnostic Get request.
  *
  * @param[in]  aInstance         A pointer to an OpenThread instance.
+ * @param[in]  aDestination      A pointer to destination address.
+ * @param[in]  aTlvTypes         An array of Network Diagnostic TLV types.
+ * @param[in]  aCount            Number of types in aTlvTypes.
  * @param[in]  aCallback         A pointer to a function that is called when Network Diagnostic Get response
  *                               is received or NULL to disable the callback.
  * @param[in]  aCallbackContext  A pointer to application-specific context.
  *
- */
-void otThreadSetReceiveDiagnosticGetCallback(otInstance *                   aInstance,
-                                             otReceiveDiagnosticGetCallback aCallback,
-                                             void *                         aCallbackContext);
-
-/**
- * Send a Network Diagnostic Get request.
- *
- * @param[in]  aInstance      A pointer to an OpenThread instance.
- * @param[in]  aDestination   A pointer to destination address.
- * @param[in]  aTlvTypes      An array of Network Diagnostic TLV types.
- * @param[in]  aCount         Number of types in aTlvTypes.
- *
  * @retval OT_ERROR_NONE    Successfully queued the DIAG_GET.req.
  * @retval OT_ERROR_NO_BUFS Insufficient message buffers available to send DIAG_GET.req.
  *
  */
-otError otThreadSendDiagnosticGet(otInstance *        aInstance,
-                                  const otIp6Address *aDestination,
-                                  const uint8_t       aTlvTypes[],
-                                  uint8_t             aCount);
+otError otThreadSendDiagnosticGet(otInstance *                   aInstance,
+                                  const otIp6Address *           aDestination,
+                                  const uint8_t                  aTlvTypes[],
+                                  uint8_t                        aCount,
+                                  otReceiveDiagnosticGetCallback aCallback,
+                                  void *                         aCallbackContext);
 
 /**
  * Send a Network Diagnostic Reset request.
diff --git a/src/cli/cli.cpp b/src/cli/cli.cpp
index f7dcbb9..16eecdc 100644
--- a/src/cli/cli.cpp
+++ b/src/cli/cli.cpp
@@ -147,9 +147,6 @@
     , mSrpServer(*this)
 #endif
 {
-#if OPENTHREAD_FTD || OPENTHREAD_CONFIG_TMF_NETWORK_DIAG_MTD_ENABLE
-    otThreadSetReceiveDiagnosticGetCallback(mInstance, &Interpreter::HandleDiagnosticGetResponse, this);
-#endif
 #if OPENTHREAD_FTD
     otThreadSetDiscoveryRequestCallback(mInstance, &Interpreter::HandleDiscoveryRequest, this);
 #endif
@@ -4850,7 +4847,8 @@
 
     if (strcmp(aArgs[0], "get") == 0)
     {
-        IgnoreError(otThreadSendDiagnosticGet(mInstance, &address, tlvTypes, count));
+        SuccessOrExit(error = otThreadSendDiagnosticGet(mInstance, &address, tlvTypes, count,
+                                                        &Interpreter::HandleDiagnosticGetResponse, this));
         ExitNow(error = OT_ERROR_PENDING);
     }
     else if (strcmp(aArgs[0], "reset") == 0)
diff --git a/src/core/api/netdiag_api.cpp b/src/core/api/netdiag_api.cpp
index 5b75e0e..3e8b836 100644
--- a/src/core/api/netdiag_api.cpp
+++ b/src/core/api/netdiag_api.cpp
@@ -49,24 +49,17 @@
                                                                 *aIterator, *aNetworkDiagTlv);
 }
 
-void otThreadSetReceiveDiagnosticGetCallback(otInstance *                   aInstance,
-                                             otReceiveDiagnosticGetCallback aCallback,
-                                             void *                         aCallbackContext)
-{
-    Instance &instance = *static_cast<Instance *>(aInstance);
-
-    instance.Get<NetworkDiagnostic::NetworkDiagnostic>().SetReceiveDiagnosticGetCallback(aCallback, aCallbackContext);
-}
-
-otError otThreadSendDiagnosticGet(otInstance *        aInstance,
-                                  const otIp6Address *aDestination,
-                                  const uint8_t       aTlvTypes[],
-                                  uint8_t             aCount)
+otError otThreadSendDiagnosticGet(otInstance *                   aInstance,
+                                  const otIp6Address *           aDestination,
+                                  const uint8_t                  aTlvTypes[],
+                                  uint8_t                        aCount,
+                                  otReceiveDiagnosticGetCallback aCallback,
+                                  void *                         aCallbackContext)
 {
     Instance &instance = *static_cast<Instance *>(aInstance);
 
     return instance.Get<NetworkDiagnostic::NetworkDiagnostic>().SendDiagnosticGet(
-        *static_cast<const Ip6::Address *>(aDestination), aTlvTypes, aCount);
+        *static_cast<const Ip6::Address *>(aDestination), aTlvTypes, aCount, aCallback, aCallbackContext);
 }
 
 otError otThreadSendDiagnosticReset(otInstance *        aInstance,
diff --git a/src/core/thread/network_diagnostic.cpp b/src/core/thread/network_diagnostic.cpp
index a26525b..0365549 100644
--- a/src/core/thread/network_diagnostic.cpp
+++ b/src/core/thread/network_diagnostic.cpp
@@ -69,16 +69,11 @@
     Get<Tmf::TmfAgent>().AddResource(mDiagnosticReset);
 }
 
-void NetworkDiagnostic::SetReceiveDiagnosticGetCallback(otReceiveDiagnosticGetCallback aCallback,
-                                                        void *                         aCallbackContext)
-{
-    mReceiveDiagnosticGetCallback        = aCallback;
-    mReceiveDiagnosticGetCallbackContext = aCallbackContext;
-}
-
-otError NetworkDiagnostic::SendDiagnosticGet(const Ip6::Address &aDestination,
-                                             const uint8_t       aTlvTypes[],
-                                             uint8_t             aCount)
+otError NetworkDiagnostic::SendDiagnosticGet(const Ip6::Address &           aDestination,
+                                             const uint8_t                  aTlvTypes[],
+                                             uint8_t                        aCount,
+                                             otReceiveDiagnosticGetCallback aCallback,
+                                             void *                         aCallbackContext)
 {
     otError               error;
     Coap::Message *       message = nullptr;
@@ -122,6 +117,9 @@
 
     SuccessOrExit(error = Get<Tmf::TmfAgent>().SendMessage(*message, messageInfo, handler, this));
 
+    mReceiveDiagnosticGetCallback        = aCallback;
+    mReceiveDiagnosticGetCallbackContext = aCallbackContext;
+
     otLogInfoNetDiag("Sent diagnostic get");
 
 exit:
diff --git a/src/core/thread/network_diagnostic.hpp b/src/core/thread/network_diagnostic.hpp
index 510a9ea..3d7a9cf 100644
--- a/src/core/thread/network_diagnostic.hpp
+++ b/src/core/thread/network_diagnostic.hpp
@@ -82,25 +82,22 @@
     explicit NetworkDiagnostic(Instance &aInstance);
 
     /**
-     * This method registers a callback to provide received raw DIAG_GET.rsp or an DIAG_GET.ans payload.
-     *
-     * @param[in]  aCallback         A pointer to a function that is called when an DIAG_GET.rsp or an DIAG_GET.ans
-     *                               is received or nullptr to disable the callback.
-     * @param[in]  aCallbackContext  A pointer to application-specific context.
-     *
-     */
-    void SetReceiveDiagnosticGetCallback(otReceiveDiagnosticGetCallback aCallback, void *aCallbackContext);
-
-    /**
      * This method sends Diagnostic Get request. If the @p aDestination is of multicast type, the DIAG_GET.qry
      * message is sent or the DIAG_GET.req otherwise.
      *
-     * @param[in] aDestination  A reference to the destination address.
-     * @param[in] aTlvTypes     An array of Network Diagnostic TLV types.
-     * @param[in] aCount        Number of types in aTlvTypes
+     * @param[in]  aDestination      A reference to the destination address.
+     * @param[in]  aTlvTypes         An array of Network Diagnostic TLV types.
+     * @param[in]  aCount            Number of types in aTlvTypes.
+     * @param[in]  aCallback         A pointer to a function that is called when Network Diagnostic Get response
+     *                               is received or NULL to disable the callback.
+     * @param[in]  aCallbackContext  A pointer to application-specific context.
      *
      */
-    otError SendDiagnosticGet(const Ip6::Address &aDestination, const uint8_t aTlvTypes[], uint8_t aCount);
+    otError SendDiagnosticGet(const Ip6::Address &           aDestination,
+                              const uint8_t                  aTlvTypes[],
+                              uint8_t                        aCount,
+                              otReceiveDiagnosticGetCallback aCallback,
+                              void *                         aCallbackContext);
 
     /**
      * This method sends Diagnostic Reset request.