Introduce mbedtls_ecjpake_write_shared_key() to export the EC J-PAKE shared key material before the KDF()

Signed-off-by: Neil Armstrong <narmstrong@baylibre.com>
diff --git a/include/mbedtls/ecjpake.h b/include/mbedtls/ecjpake.h
index ffdea05..5b57455 100644
--- a/include/mbedtls/ecjpake.h
+++ b/include/mbedtls/ecjpake.h
@@ -259,6 +259,29 @@
                             void *p_rng );
 
 /**
+ * \brief           Write the shared key material to be passed to a Key
+ *                  Derivation Function as described in RFC8236.
+ *
+ * \param ctx       The ECJPAKE context to use. This must be initialized,
+ *                  set up and have performed both round one and two.
+ * \param buf       The buffer to write the derived secret to. This must
+ *                  be a writable buffer of length \p len Bytes.
+ * \param len       The length of \p buf in Bytes.
+ * \param olen      The address at which to store the total number of Bytes
+ *                  written to \p buf. This must not be \c NULL.
+ * \param f_rng     The RNG function to use. This must not be \c NULL.
+ * \param p_rng     The RNG parameter to be passed to \p f_rng. This
+ *                  may be \c NULL if \p f_rng doesn't use a context.
+ *
+ * \return          \c 0 if successful.
+ * \return          A negative error code on failure.
+ */
+int mbedtls_ecjpake_write_shared_key( mbedtls_ecjpake_context *ctx,
+                            unsigned char *buf, size_t len, size_t *olen,
+                            int (*f_rng)(void *, unsigned char *, size_t),
+                            void *p_rng );
+
+/**
  * \brief           This clears an ECJPAKE context and frees any
  *                  embedded data structure.
  *
diff --git a/library/ecjpake.c b/library/ecjpake.c
index 7447354..17d0438 100644
--- a/library/ecjpake.c
+++ b/library/ecjpake.c
@@ -760,22 +760,14 @@
 /*
  * Derive PMS (7.4.2.7 / 7.4.2.8)
  */
-int mbedtls_ecjpake_derive_secret( mbedtls_ecjpake_context *ctx,
-                            unsigned char *buf, size_t len, size_t *olen,
-                            int (*f_rng)(void *, unsigned char *, size_t),
-                            void *p_rng )
+static int mbedtls_ecjpake_derive_k( mbedtls_ecjpake_context *ctx,
+                                mbedtls_ecp_point *K,
+                                int (*f_rng)(void *, unsigned char *, size_t),
+                                void *p_rng )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    mbedtls_ecp_point K;
     mbedtls_mpi m_xm2_s, one;
-    unsigned char kx[MBEDTLS_ECP_MAX_BYTES];
-    size_t x_bytes;
 
-    *olen = mbedtls_hash_info_get_size( ctx->md_type );
-    if( len < *olen )
-        return( MBEDTLS_ERR_ECP_BUFFER_TOO_SMALL );
-
-    mbedtls_ecp_point_init( &K );
     mbedtls_mpi_init( &m_xm2_s );
     mbedtls_mpi_init( &one );
 
@@ -788,12 +780,39 @@
      */
     MBEDTLS_MPI_CHK( ecjpake_mul_secret( &m_xm2_s, -1, &ctx->xm2, &ctx->s,
                                          &ctx->grp.N, f_rng, p_rng ) );
-    MBEDTLS_MPI_CHK( mbedtls_ecp_muladd( &ctx->grp, &K,
+    MBEDTLS_MPI_CHK( mbedtls_ecp_muladd( &ctx->grp, K,
                                          &one, &ctx->Xp,
                                          &m_xm2_s, &ctx->Xp2 ) );
-    MBEDTLS_MPI_CHK( mbedtls_ecp_mul( &ctx->grp, &K, &ctx->xm2, &K,
+    MBEDTLS_MPI_CHK( mbedtls_ecp_mul( &ctx->grp, K, &ctx->xm2, K,
                                       f_rng, p_rng ) );
 
+cleanup:
+    mbedtls_mpi_free( &m_xm2_s );
+    mbedtls_mpi_free( &one );
+
+    return( ret );
+}
+
+int mbedtls_ecjpake_derive_secret( mbedtls_ecjpake_context *ctx,
+                            unsigned char *buf, size_t len, size_t *olen,
+                            int (*f_rng)(void *, unsigned char *, size_t),
+                            void *p_rng )
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    mbedtls_ecp_point K;
+    unsigned char kx[MBEDTLS_ECP_MAX_BYTES];
+    size_t x_bytes;
+
+    *olen = mbedtls_hash_info_get_size( ctx->md_type );
+    if( len < *olen )
+        return( MBEDTLS_ERR_ECP_BUFFER_TOO_SMALL );
+
+    mbedtls_ecp_point_init( &K );
+
+    ret = mbedtls_ecjpake_derive_k(ctx, &K, f_rng, p_rng);
+    if( ret )
+        goto cleanup;
+
     /* PMS = SHA-256( K.X ) */
     x_bytes = ( ctx->grp.pbits + 7 ) / 8;
     MBEDTLS_MPI_CHK( mbedtls_mpi_write_binary( &K.X, kx, x_bytes ) );
@@ -802,8 +821,31 @@
 
 cleanup:
     mbedtls_ecp_point_free( &K );
-    mbedtls_mpi_free( &m_xm2_s );
-    mbedtls_mpi_free( &one );
+
+    return( ret );
+}
+
+int mbedtls_ecjpake_write_shared_key( mbedtls_ecjpake_context *ctx,
+                            unsigned char *buf, size_t len, size_t *olen,
+                            int (*f_rng)(void *, unsigned char *, size_t),
+                            void *p_rng )
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    mbedtls_ecp_point K;
+
+    mbedtls_ecp_point_init( &K );
+
+    ret = mbedtls_ecjpake_derive_k(ctx, &K, f_rng, p_rng);
+    if( ret )
+        goto cleanup;
+
+    ret = mbedtls_ecp_point_write_binary( &ctx->grp, &K, ctx->point_format,
+                                          olen, buf, len );
+    if( ret != 0 )
+        goto cleanup;
+
+cleanup:
+    mbedtls_ecp_point_free( &K );
 
     return( ret );
 }
@@ -958,6 +1000,15 @@
     0xcc, 0x38, 0xdb, 0xdc, 0xae, 0x60, 0xd9, 0xc5, 0x4c
 };
 
+static const unsigned char ecjpake_test_shared_key[] = {
+    0x04, 0x01, 0xab, 0xe9, 0xf2, 0xc7, 0x3a, 0x99, 0x14, 0xcb, 0x1f, 0x80,
+    0xfb, 0x9d, 0xdb, 0x7e, 0x00, 0x12, 0xa8, 0x9c, 0x2f, 0x39, 0x27, 0x79,
+    0xf9, 0x64, 0x40, 0x14, 0x75, 0xea, 0xc1, 0x31, 0x28, 0x43, 0x8f, 0xe1,
+    0x12, 0x41, 0xd6, 0xc1, 0xe5, 0x5f, 0x7b, 0x80, 0x88, 0x94, 0xc9, 0xc0,
+    0x27, 0xa3, 0x34, 0x41, 0xf5, 0xcb, 0xa1, 0xfe, 0x6c, 0xc7, 0xe6, 0x12,
+    0x17, 0xc3, 0xde, 0x27, 0xb4,
+};
+
 static const unsigned char ecjpake_test_pms[] = {
     0xf3, 0xd4, 0x7f, 0x59, 0x98, 0x44, 0xdb, 0x92, 0xa5, 0x69, 0xbb, 0xe7,
     0x98, 0x1e, 0x39, 0xd9, 0x31, 0xfd, 0x74, 0x3b, 0xf2, 0x2e, 0x98, 0xf9,
@@ -1144,6 +1195,13 @@
     TEST_ASSERT( len == sizeof( ecjpake_test_pms ) );
     TEST_ASSERT( memcmp( buf, ecjpake_test_pms, len ) == 0 );
 
+    /* Server derives K as unsigned binary data */
+    TEST_ASSERT( mbedtls_ecjpake_write_shared_key( &srv,
+                 buf, sizeof( buf ), &len, ecjpake_lgc, NULL ) == 0 );
+
+    TEST_ASSERT( len == sizeof( ecjpake_test_shared_key ) );
+    TEST_ASSERT( memcmp( buf, ecjpake_test_shared_key, len ) == 0 );
+
     memset( buf, 0, len ); /* Avoid interferences with next step */
 
     /* Client derives PMS */
@@ -1153,6 +1211,13 @@
     TEST_ASSERT( len == sizeof( ecjpake_test_pms ) );
     TEST_ASSERT( memcmp( buf, ecjpake_test_pms, len ) == 0 );
 
+    /* Client derives K as unsigned binary data */
+    TEST_ASSERT( mbedtls_ecjpake_write_shared_key( &cli,
+                 buf, sizeof( buf ), &len, ecjpake_lgc, NULL ) == 0 );
+
+    TEST_ASSERT( len == sizeof( ecjpake_test_shared_key ) );
+    TEST_ASSERT( memcmp( buf, ecjpake_test_shared_key, len ) == 0 );
+
     if( verbose != 0 )
         mbedtls_printf( "passed\n" );
 #endif /* ! MBEDTLS_ECJPAKE_ALT */