diff options
Diffstat (limited to 'feature-add-SMx-support.patch')
-rw-r--r-- | feature-add-SMx-support.patch | 1507 |
1 files changed, 1507 insertions, 0 deletions
diff --git a/feature-add-SMx-support.patch b/feature-add-SMx-support.patch new file mode 100644 index 0000000..e8b0f3c --- /dev/null +++ b/feature-add-SMx-support.patch @@ -0,0 +1,1507 @@ +From d2e28809c673f914b49147ca3fa31e08b9e885d7 Mon Sep 17 00:00:00 2001 +From: renmingshuai <renmingshuai@huawei.com> +Date: Sat, 29 Jul 2023 10:50:29 +0800 +Subject: [PATCH] feature add sm2 + +--- + Makefile.in | 4 +- + authfd.c | 2 + + authfile.c | 1 + + cipher.c | 1 + + digest-openssl.c | 1 + + digest.h | 3 +- + kex.c | 1 + + kex.h | 3 + + kexecdh.c | 23 +- + kexgen.c | 3 + + kexsm2.c | 406 ++++++++++++++++++++++++++ + mac.c | 1 + + pathnames.h | 1 + + regress/agent.sh | 9 + + regress/keytype.sh | 2 + + regress/knownhosts-command.sh | 1 + + regress/misc/fuzz-harness/sig_fuzz.cc | 4 + + regress/unittests/kex/test_kex.c | 3 + + ssh-ecdsa.c | 6 +- + ssh-keygen.c | 12 +- + ssh-keyscan.c | 12 +- + ssh-sm2.c | 381 ++++++++++++++++++++++++ + ssh_api.c | 2 + + sshconnect2.c | 1 + + sshd.c | 7 + + sshkey.c | 21 ++ + sshkey.h | 2 + + 27 files changed, 899 insertions(+), 14 deletions(-) + create mode 100644 kexsm2.c + create mode 100644 ssh-sm2.c + +diff --git a/Makefile.in b/Makefile.in +index 5fec5b3..7dcda3e 100644 +--- a/Makefile.in ++++ b/Makefile.in +@@ -102,14 +102,14 @@ LIBSSH_OBJS=${LIBOPENSSH_OBJS} \ + log.o match.o moduli.o nchan.o packet.o \ + readpass.o ttymodes.o xmalloc.o addr.o addrmatch.o \ + atomicio.o dispatch.o mac.o misc.o utf8.o \ +- monitor_fdpass.o rijndael.o ssh-dss.o ssh-ecdsa.o ssh-ecdsa-sk.o \ ++ monitor_fdpass.o rijndael.o ssh-dss.o ssh-ecdsa.o ssh-sm2.o ssh-ecdsa-sk.o \ + ssh-ed25519-sk.o ssh-rsa.o dh.o \ + msg.o progressmeter.o dns.o entropy.o gss-genr.o umac.o umac128.o \ + ssh-pkcs11.o ssh-pkcs11-uri.o smult_curve25519_ref.o \ + poly1305.o chacha.o cipher-chachapoly.o cipher-chachapoly-libcrypto.o \ + ssh-ed25519.o digest-openssl.o digest-libc.o \ + hmac.o ed25519.o hash.o \ +- kex.o kexdh.o kexgex.o kexecdh.o kexc25519.o \ ++ kex.o kexdh.o kexgex.o kexecdh.o kexc25519.o kexsm2.o \ + kexgexc.o kexgexs.o \ + kexsntrup761x25519.o sntrup761.o kexgen.o \ + kexgssc.o \ +diff --git a/authfd.c b/authfd.c +index 25a3636..bcc25a7 100644 +--- a/authfd.c ++++ b/authfd.c +@@ -583,6 +583,8 @@ ssh_add_identity_constrained(int sock, struct sshkey *key, + case KEY_DSA_CERT: + case KEY_ECDSA: + case KEY_ECDSA_CERT: ++ case KEY_SM2: ++ case KEY_SM2_CERT: + case KEY_ECDSA_SK: + case KEY_ECDSA_SK_CERT: + #endif +diff --git a/authfile.c b/authfile.c +index 445f2dd..3884031 100644 +--- a/authfile.c ++++ b/authfile.c +@@ -332,6 +332,7 @@ sshkey_load_private_cert(int type, const char *filename, const char *passphrase, + case KEY_RSA: + case KEY_DSA: + case KEY_ECDSA: ++ case KEY_SM2: + #endif /* WITH_OPENSSL */ + case KEY_ED25519: + case KEY_XMSS: +diff --git a/cipher.c b/cipher.c +index 609450d..7f98413 100644 +--- a/cipher.c ++++ b/cipher.c +@@ -86,6 +86,7 @@ static const struct sshcipher ciphers[] = { + #endif + { "chacha20-poly1305@openssh.com", + 8, 64, 0, 16, CFLAG_CHACHAPOLY, NULL }, ++ { "sm4-ctr", 16, 16, 0, 0, 0, EVP_sm4_ctr }, + { "none", 8, 0, 0, 0, CFLAG_NONE, NULL }, + + { NULL, 0, 0, 0, 0, 0, NULL } +diff --git a/digest-openssl.c b/digest-openssl.c +index 94730e9..fa92360 100644 +--- a/digest-openssl.c ++++ b/digest-openssl.c +@@ -61,6 +61,7 @@ const struct ssh_digest digests[] = { + { SSH_DIGEST_SHA256, "SHA256", 32, EVP_sha256 }, + { SSH_DIGEST_SHA384, "SHA384", 48, EVP_sha384 }, + { SSH_DIGEST_SHA512, "SHA512", 64, EVP_sha512 }, ++ { SSH_DIGEST_SM3, "SM3", 32, EVP_sm3 }, + { -1, NULL, 0, NULL }, + }; + +diff --git a/digest.h b/digest.h +index c7ceeb3..520722c 100644 +--- a/digest.h ++++ b/digest.h +@@ -27,7 +27,8 @@ + #define SSH_DIGEST_SHA256 2 + #define SSH_DIGEST_SHA384 3 + #define SSH_DIGEST_SHA512 4 +-#define SSH_DIGEST_MAX 5 ++#define SSH_DIGEST_SM3 5 ++#define SSH_DIGEST_MAX 6 + + struct sshbuf; + struct ssh_digest_ctx; +diff --git a/kex.c b/kex.c +index 0fbd0ca..e9dfcc2 100644 +--- a/kex.c ++++ b/kex.c +@@ -125,6 +125,7 @@ static const struct kexalg kexalgs[] = { + SSH_DIGEST_SHA512 }, + #endif + #endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */ ++ { "sm2-sm3", KEX_SM2_SM3, NID_sm2, SSH_DIGEST_SM3 }, + { NULL, 0, -1, -1}, + }; + static const struct kexalg gss_kexalgs[] = { +diff --git a/kex.h b/kex.h +index 0fac9d3..044ec18 100644 +--- a/kex.h ++++ b/kex.h +@@ -102,6 +102,7 @@ enum kex_exchange { + KEX_ECDH_SHA2, + KEX_C25519_SHA256, + KEX_KEM_SNTRUP761X25519_SHA512, ++ KEX_SM2_SM3, + #ifdef GSSAPI + KEX_GSS_GRP1_SHA1, + KEX_GSS_GRP14_SHA1, +@@ -287,6 +288,8 @@ int kexc25519_shared_key_ext(const u_char key[CURVE25519_SIZE], + __attribute__((__bounded__(__minbytes__, 1, CURVE25519_SIZE))) + __attribute__((__bounded__(__minbytes__, 2, CURVE25519_SIZE))); + ++int SM2KAP_compute_key(void *out, size_t outlen, const EC_POINT *pub_key, const EC_KEY *eckey, int server); ++ + #if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH) + void dump_digest(const char *, const u_char *, int); + #endif +diff --git a/kexecdh.c b/kexecdh.c +index efb2e55..69ec13b 100644 +--- a/kexecdh.c ++++ b/kexecdh.c +@@ -44,7 +44,7 @@ + + static int + kex_ecdh_dec_key_group(struct kex *, const struct sshbuf *, EC_KEY *key, +- const EC_GROUP *, struct sshbuf **); ++ const EC_GROUP *, struct sshbuf **, int server); + + int + kex_ecdh_keypair(struct kex *kex) +@@ -124,7 +124,7 @@ kex_ecdh_enc(struct kex *kex, const struct sshbuf *client_blob, + (r = sshbuf_get_u32(server_blob, NULL)) != 0) + goto out; + if ((r = kex_ecdh_dec_key_group(kex, client_blob, server_key, group, +- shared_secretp)) != 0) ++ shared_secretp, 1)) != 0) + goto out; + *server_blobp = server_blob; + server_blob = NULL; +@@ -136,7 +136,7 @@ kex_ecdh_enc(struct kex *kex, const struct sshbuf *client_blob, + + static int + kex_ecdh_dec_key_group(struct kex *kex, const struct sshbuf *ec_blob, +- EC_KEY *key, const EC_GROUP *group, struct sshbuf **shared_secretp) ++ EC_KEY *key, const EC_GROUP *group, struct sshbuf **shared_secretp, int server) + { + struct sshbuf *buf = NULL; + BIGNUM *shared_secret = NULL; +@@ -176,11 +176,20 @@ kex_ecdh_dec_key_group(struct kex *kex, const struct sshbuf *ec_blob, + r = SSH_ERR_ALLOC_FAIL; + goto out; + } +- if (ECDH_compute_key(kbuf, klen, dh_pub, key, NULL) != (int)klen || ++ if (kex->ec_nid == NID_sm2) { ++ if (SM2KAP_compute_key(kbuf, klen, dh_pub, key, server) != (int)klen || + BN_bin2bn(kbuf, klen, shared_secret) == NULL) { +- r = SSH_ERR_LIBCRYPTO_ERROR; +- goto out; ++ r = SSH_ERR_LIBCRYPTO_ERROR; ++ goto out; ++ } ++ } else { ++ if (ECDH_compute_key(kbuf, klen, dh_pub, key, NULL) != (int)klen || ++ BN_bin2bn(kbuf, klen, shared_secret) == NULL) { ++ r = SSH_ERR_LIBCRYPTO_ERROR; ++ goto out; ++ } + } ++ + #ifdef DEBUG_KEXECDH + dump_digest("shared secret", kbuf, klen); + #endif +@@ -203,7 +212,7 @@ kex_ecdh_dec(struct kex *kex, const struct sshbuf *server_blob, + int r; + + r = kex_ecdh_dec_key_group(kex, server_blob, kex->ec_client_key, +- kex->ec_group, shared_secretp); ++ kex->ec_group, shared_secretp, 0); + EC_KEY_free(kex->ec_client_key); + kex->ec_client_key = NULL; + return r; +diff --git a/kexgen.c b/kexgen.c +index ca70484..4855d5c 100644 +--- a/kexgen.c ++++ b/kexgen.c +@@ -111,6 +111,7 @@ kex_gen_client(struct ssh *ssh) + r = kex_dh_keypair(kex); + break; + case KEX_ECDH_SHA2: ++ case KEX_SM2_SM3: + r = kex_ecdh_keypair(kex); + break; + #endif +@@ -182,6 +183,7 @@ input_kex_gen_reply(int type, u_int32_t seq, struct ssh *ssh) + r = kex_dh_dec(kex, server_blob, &shared_secret); + break; + case KEX_ECDH_SHA2: ++ case KEX_SM2_SM3: + r = kex_ecdh_dec(kex, server_blob, &shared_secret); + break; + #endif +@@ -298,6 +300,7 @@ input_kex_gen_init(int type, u_int32_t seq, struct ssh *ssh) + &shared_secret); + break; + case KEX_ECDH_SHA2: ++ case KEX_SM2_SM3: + r = kex_ecdh_enc(kex, client_pubkey, &server_pubkey, + &shared_secret); + break; +diff --git a/kexsm2.c b/kexsm2.c +new file mode 100644 +index 0000000..f507557 +--- /dev/null ++++ b/kexsm2.c +@@ -0,0 +1,406 @@ ++#include <openssl/err.h> ++#include <openssl/evp.h> ++#include <openssl/bn.h> ++#include <string.h> ++#include <openssl/ecdh.h> ++#include <openssl/ec.h> ++ ++int sm2_compute_z_digest(uint8_t *out, ++ const EVP_MD *digest, ++ const uint8_t *id, ++ const size_t id_len, ++ const EC_KEY *key) ++{ ++ int rc = 0; ++ const EC_GROUP *group = EC_KEY_get0_group(key); ++ BN_CTX *ctx = NULL; ++ EVP_MD_CTX *hash = NULL; ++ BIGNUM *p = NULL; ++ BIGNUM *a = NULL; ++ BIGNUM *b = NULL; ++ BIGNUM *xG = NULL; ++ BIGNUM *yG = NULL; ++ BIGNUM *xA = NULL; ++ BIGNUM *yA = NULL; ++ int p_bytes = 0; ++ uint8_t *buf = NULL; ++ uint16_t entl = 0; ++ uint8_t e_byte = 0; ++ ++ hash = EVP_MD_CTX_new(); ++ ctx = BN_CTX_new(); ++ if (hash == NULL || ctx == NULL) { ++ goto done; ++ } ++ ++ p = BN_CTX_get(ctx); ++ a = BN_CTX_get(ctx); ++ b = BN_CTX_get(ctx); ++ xG = BN_CTX_get(ctx); ++ yG = BN_CTX_get(ctx); ++ xA = BN_CTX_get(ctx); ++ yA = BN_CTX_get(ctx); ++ ++ if (yA == NULL) { ++ goto done; ++ } ++ ++ if (!EVP_DigestInit(hash, digest)) { ++ goto done; ++ } ++ ++ /* Z = h(ENTL || ID || a || b || xG || yG || xA || yA) */ ++ ++ if (id_len >= (UINT16_MAX / 8)) { ++ /* too large */ ++ goto done; ++ } ++ ++ entl = (uint16_t)(8 * id_len); ++ ++ e_byte = entl >> 8; ++ if (!EVP_DigestUpdate(hash, &e_byte, 1)) { ++ goto done; ++ } ++ e_byte = entl & 0xFF; ++ if (!EVP_DigestUpdate(hash, &e_byte, 1)) { ++ goto done; ++ } ++ ++ if (id_len > 0 && !EVP_DigestUpdate(hash, id, id_len)) { ++ goto done; ++ } ++ ++ if (!EC_GROUP_get_curve(group, p, a, b, ctx)) { ++ goto done; ++ } ++ ++ p_bytes = BN_num_bytes(p); ++ buf = OPENSSL_zalloc(p_bytes); ++ if (buf == NULL) { ++ goto done; ++ } ++ ++ if (BN_bn2binpad(a, buf, p_bytes) < 0 ++ || !EVP_DigestUpdate(hash, buf, p_bytes) ++ || BN_bn2binpad(b, buf, p_bytes) < 0 ++ || !EVP_DigestUpdate(hash, buf, p_bytes) ++ || !EC_POINT_get_affine_coordinates(group, ++ EC_GROUP_get0_generator(group), ++ xG, yG, ctx) ++ || BN_bn2binpad(xG, buf, p_bytes) < 0 ++ || !EVP_DigestUpdate(hash, buf, p_bytes) ++ || BN_bn2binpad(yG, buf, p_bytes) < 0 ++ || !EVP_DigestUpdate(hash, buf, p_bytes) ++ || !EC_POINT_get_affine_coordinates(group, ++ EC_KEY_get0_public_key(key), ++ xA, yA, ctx) ++ || BN_bn2binpad(xA, buf, p_bytes) < 0 ++ || !EVP_DigestUpdate(hash, buf, p_bytes) ++ || BN_bn2binpad(yA, buf, p_bytes) < 0 ++ || !EVP_DigestUpdate(hash, buf, p_bytes) ++ || !EVP_DigestFinal(hash, out, NULL)) { ++ goto done; ++ } ++ ++ rc = 1; ++ ++ done: ++ OPENSSL_free(buf); ++ BN_CTX_free(ctx); ++ EVP_MD_CTX_free(hash); ++ return rc; ++} ++ ++ ++/* GM/T003_2012 Defined Key Derive Function */ ++int kdf_gmt003_2012(unsigned char *out, size_t outlen, const unsigned char *Z, size_t Zlen, const unsigned char *SharedInfo, size_t SharedInfolen, const EVP_MD *md) ++{ ++ EVP_MD_CTX *mctx = NULL; ++ unsigned int counter; ++ unsigned char ctr[4]; ++ size_t mdlen; ++ int retval = 0; ++ unsigned char dgst[EVP_MAX_MD_SIZE]; ++ ++ if (!out || !outlen) return retval; ++ if (md == NULL) { ++ md = EVP_sm3(); ++ } ++ mdlen = EVP_MD_size(md); ++ mctx = EVP_MD_CTX_new(); ++ if (mctx == NULL) { ++ goto err; ++ } ++ ++ for (counter = 1;; counter++) { ++ if (!EVP_DigestInit(mctx, md)) { ++ goto err; ++ } ++ ctr[0] = (unsigned char)((counter >> 24) & 0xFF); ++ ctr[1] = (unsigned char)((counter >> 16) & 0xFF); ++ ctr[2] = (unsigned char)((counter >> 8) & 0xFF); ++ ctr[3] = (unsigned char)(counter & 0xFF); ++ ++ if (!EVP_DigestUpdate(mctx, Z, Zlen)) { ++ goto err; ++ } ++ if (!EVP_DigestUpdate(mctx, ctr, sizeof(ctr))) { ++ goto err; ++ } ++ if (!EVP_DigestUpdate(mctx, SharedInfo, SharedInfolen)) { ++ goto err; ++ } ++ if (!EVP_DigestFinal(mctx, dgst, NULL)) { ++ goto err; ++ } ++ ++ if (outlen > mdlen) { ++ memcpy(out, dgst, mdlen); ++ out += mdlen; ++ outlen -= mdlen; ++ } else { ++ memcpy(out, dgst, outlen); ++ memset(dgst, 0, mdlen); ++ break; ++ } ++ } ++ ++ retval = 1; ++ ++err: ++ EVP_MD_CTX_free(mctx); ++ return retval; ++} ++ ++int sm2_kap_compute_key(void *out, size_t outlen, int server,\ ++ const uint8_t *peer_uid, int peer_uid_len, const uint8_t *self_uid, int self_uid_len, \ ++ const EC_KEY *peer_ecdhe_key, const EC_KEY *self_ecdhe_key, const EC_KEY *peer_pub_key, const EC_KEY *self_eckey, \ ++ const EVP_MD *md) ++{ ++ BN_CTX *ctx = NULL; ++ EC_POINT *UorV = NULL; ++ const EC_POINT *Rs, *Rp; ++ BIGNUM *Xs = NULL, *Xp = NULL, *h = NULL, *t = NULL, *two_power_w = NULL, *order = NULL; ++ const BIGNUM *priv_key, *r; ++ const EC_GROUP *group; ++ int w; ++ int ret = -1; ++ size_t buflen, len; ++ unsigned char *buf = NULL; ++ ++ if (outlen > INT_MAX) { ++ goto err; ++ } ++ ++ if (!peer_pub_key || !self_eckey) { ++ goto err; ++ } ++ ++ priv_key = EC_KEY_get0_private_key(self_eckey); ++ if (!priv_key) { ++ goto err; ++ } ++ ++ if (!peer_ecdhe_key || !self_ecdhe_key) { ++ goto err; ++ } ++ ++ Rs = EC_KEY_get0_public_key(self_ecdhe_key); ++ Rp = EC_KEY_get0_public_key(peer_ecdhe_key); ++ r = EC_KEY_get0_private_key(self_ecdhe_key); ++ ++ if (!Rs || !Rp || !r) { ++ goto err; ++ } ++ ++ ctx = BN_CTX_new(); ++ Xs = BN_new(); ++ Xp = BN_new(); ++ h = BN_new(); ++ t = BN_new(); ++ two_power_w = BN_new(); ++ order = BN_new(); ++ if (!Xs || !Xp || !h || !t || !two_power_w || !order) { ++ goto err; ++ } ++ ++ group = EC_KEY_get0_group(self_eckey); ++ ++ /*Second: Caculate -- w*/ ++ if (!EC_GROUP_get_order(group, order, ctx) || !EC_GROUP_get_cofactor(group, h, ctx)) { ++ goto err; ++ } ++ ++ w = (BN_num_bits(order) + 1) / 2 - 1; ++ if (!BN_lshift(two_power_w, BN_value_one(), w)) { ++ goto err; ++ } ++ ++ /*Third: Caculate -- X = 2 ^ w + (x & (2 ^ w - 1)) = 2 ^ w + (x mod 2 ^ w)*/ ++ UorV = EC_POINT_new(group); ++ ++ if (!UorV) { ++ goto err; ++ } ++ ++ /*Test peer public key On curve*/ ++ if (!EC_POINT_is_on_curve(group, Rp, ctx)) { ++ goto err; ++ } ++ ++ /*Get x*/ ++ if (EC_METHOD_get_field_type(EC_GROUP_method_of(group)) == NID_X9_62_prime_field) { ++ if (!EC_POINT_get_affine_coordinates_GFp(group, Rs, Xs, NULL, ctx)) { ++ goto err; ++ } ++ ++ if (!EC_POINT_get_affine_coordinates_GFp(group, Rp, Xp, NULL, ctx)) { ++ goto err; ++ } ++ } ++ ++ /*x mod 2 ^ w*/ ++ /*Caculate Self x*/ ++ if (!BN_nnmod(Xs, Xs, two_power_w, ctx)) { ++ goto err; ++ } ++ ++ if (!BN_add(Xs, Xs, two_power_w)) { ++ goto err; ++ } ++ ++ /*Caculate Peer x*/ ++ if (!BN_nnmod(Xp, Xp, two_power_w, ctx)) { ++ goto err; ++ } ++ ++ if (!BN_add(Xp, Xp, two_power_w)) { ++ goto err; ++ } ++ ++ /*Forth: Caculate t*/ ++ if (!BN_mod_mul(t, Xs, r, order, ctx)) { ++ goto err; ++ } ++ ++ if (!BN_mod_add(t, t, priv_key, order, ctx)) { ++ goto err; ++ } ++ ++ /*Fifth: Caculate V or U*/ ++ if (!BN_mul(t, t, h, ctx)) { ++ goto err; ++ } ++ ++ /* [x]R */ ++ if (!EC_POINT_mul(group, UorV, NULL, Rp, Xp, ctx)) { ++ goto err; ++ } ++ ++ /* P + [x]R */ ++ if (!EC_POINT_add(group, UorV, UorV, EC_KEY_get0_public_key(peer_pub_key), ctx)) { ++ goto err; ++ } ++ ++ if (!EC_POINT_mul(group, UorV, NULL, UorV, t, ctx)) { ++ goto err; ++ } ++ ++ /* Detect UorV is in */ ++ if (EC_POINT_is_at_infinity(group, UorV)) { ++ goto err; ++ } ++ ++ /*Sixth: Caculate Key -- Need Xuorv, Yuorv, Zc, Zs, klen*/ ++ { ++ /* ++ size_t buflen, len; ++ unsigned char *buf = NULL; ++ */ ++ size_t elemet_len, idx; ++ ++ elemet_len = (size_t)((EC_GROUP_get_degree(group) + 7) / 8); ++ buflen = elemet_len * 2 + 32 * 2 + 1; /*add 1 byte tag*/ ++ buf = (unsigned char *)OPENSSL_malloc(buflen + 10); ++ if (!buf) { ++ goto err; ++ } ++ memset(buf, 0, buflen + 10); ++ /*1 : Get public key for UorV, Notice: the first byte is a tag, not a valid char*/ ++ idx = EC_POINT_point2oct(group, UorV, 4, buf, buflen, ctx); ++ if (!idx) { ++ goto err; ++ } ++ ++ if (!server) { ++ /*SIDE A*/ ++ len = buflen - idx; ++ if (!sm2_compute_z_digest( (unsigned char *)(buf + idx), md, (const uint8_t *)self_uid, self_uid_len, self_eckey)) { ++ goto err; ++ } ++ len = 32; ++ idx += len; ++ } ++ ++ /*Caculate Peer Z*/ ++ len = buflen - idx; ++ if (!sm2_compute_z_digest( (unsigned char *)(buf + idx), md, (const uint8_t *)peer_uid, peer_uid_len, peer_pub_key)) { ++ goto err; ++ } ++ len = 32; ++ idx += len; ++ ++ if (server) { ++ /*SIDE B*/ ++ len = buflen - idx; ++ if (!sm2_compute_z_digest( (unsigned char *)(buf + idx), md, (const uint8_t *)self_uid, self_uid_len, self_eckey)) { ++ goto err; ++ } ++ len = 32; ++ idx += len; ++ } ++ ++ len = outlen; ++ if (!kdf_gmt003_2012(out, len, (const unsigned char *)(buf + 1), idx - 1, NULL, 0, md)) { ++ goto err; ++ } ++ } ++ ++ ret = outlen; ++ ++err: ++ if (Xs) BN_free(Xs); ++ if (Xp) BN_free(Xp); ++ if (h) BN_free(h); ++ if (t) BN_free(t); ++ if (two_power_w) BN_free(two_power_w); ++ if (order) BN_free(order); ++ if (UorV) EC_POINT_free(UorV); ++ if (buf) OPENSSL_free(buf); ++ if (ctx) BN_CTX_free(ctx); ++ ++ return ret; ++} ++ ++int SM2KAP_compute_key(void *out, size_t outlen, const EC_POINT *pub_key, const EC_KEY *eckey, int server) ++{ ++ int ret = 0; ++ EC_KEY *pubkey = NULL; ++ unsigned char id[16] = {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}; ++ ++ if ((pubkey = EC_KEY_new_by_curve_name(NID_sm2)) == NULL) { ++ return ret; ++ } ++ ++ if (EC_KEY_set_public_key(pubkey, pub_key) != 1) { ++ ret = 0; ++ goto out; ++ } ++ ++ ret = sm2_kap_compute_key(out, outlen, server, id, sizeof(id), id, sizeof(id), pubkey, eckey, pubkey, eckey, (EVP_MD*)EVP_sm3()); ++ ++out: ++ EC_KEY_free(pubkey); ++ return ret; ++} +diff --git a/mac.c b/mac.c +index bf051ba..2de17a0 100644 +--- a/mac.c ++++ b/mac.c +@@ -65,6 +65,7 @@ static const struct macalg macs[] = { + { "hmac-md5-96", SSH_DIGEST, SSH_DIGEST_MD5, 96, 0, 0, 0 }, + { "umac-64@openssh.com", SSH_UMAC, 0, 0, 128, 64, 0 }, + { "umac-128@openssh.com", SSH_UMAC128, 0, 0, 128, 128, 0 }, ++ { "hmac-sm3", SSH_DIGEST, SSH_DIGEST_SM3, 0, 0, 0, 0 }, + + /* Encrypt-then-MAC variants */ + { "hmac-sha1-etm@openssh.com", SSH_DIGEST, SSH_DIGEST_SHA1, 0, 0, 0, 1 }, +diff --git a/pathnames.h b/pathnames.h +index a094888..0a805ad 100644 +--- a/pathnames.h ++++ b/pathnames.h +@@ -80,6 +80,7 @@ + #define _PATH_SSH_CLIENT_ID_XMSS _PATH_SSH_USER_DIR "/id_xmss" + #define _PATH_SSH_CLIENT_ID_ECDSA_SK _PATH_SSH_USER_DIR "/id_ecdsa_sk" + #define _PATH_SSH_CLIENT_ID_ED25519_SK _PATH_SSH_USER_DIR "/id_ed25519_sk" ++#define _PATH_SSH_CLIENT_ID_SM2 _PATH_SSH_USER_DIR "/id_sm2" + + /* + * Configuration file in user's home directory. This file need not be +diff --git a/regress/agent.sh b/regress/agent.sh +index 5f10606..3ab40b4 100644 +--- a/regress/agent.sh ++++ b/regress/agent.sh +@@ -87,9 +87,18 @@ fi + for t in ${SSH_KEYTYPES}; do + trace "connect via agent using $t key" + if [ "$t" = "ssh-dss" ]; then ++ sed -i "/PubkeyAcceptedAlgorithms/d" $OBJ/ssh_proxy ++ sed -i "/PubkeyAcceptedAlgorithms/d" $OBJ/sshd_proxy + echo "PubkeyAcceptedAlgorithms +ssh-dss" >> $OBJ/ssh_proxy + echo "PubkeyAcceptedAlgorithms +ssh-dss" >> $OBJ/sshd_proxy + fi ++ if [ "$t" = "sm2" ]; then ++ sed -i "/PubkeyAcceptedAlgorithms/d" $OBJ/ssh_proxy ++ sed -i "/PubkeyAcceptedAlgorithms/d" $OBJ/sshd_proxy ++ echo "PubkeyAcceptedAlgorithms +sm2,sm2-cert" >> $OBJ/ssh_proxy ++ echo "PubkeyAcceptedAlgorithms +sm2,sm2-cert" >> $OBJ/sshd_proxy ++ fi ++ + ${SSH} -F $OBJ/ssh_proxy -i $OBJ/$t-agent.pub -oIdentitiesOnly=yes \ + somehost exit 52 + r=$? +diff --git a/regress/keytype.sh b/regress/keytype.sh +index f1c0451..2665bd6 100644 +--- a/regress/keytype.sh ++++ b/regress/keytype.sh +@@ -18,6 +18,7 @@ for i in ${SSH_KEYTYPES}; do + ecdsa-sha2-nistp521) ktypes="$ktypes ecdsa-521" ;; + sk-ssh-ed25519*) ktypes="$ktypes ed25519-sk" ;; + sk-ecdsa-sha2-nistp256*) ktypes="$ktypes ecdsa-sk" ;; ++ sm2) ktypes="$ktypes sm2-256" ;; + esac + done + +@@ -44,6 +45,7 @@ kname_to_ktype() { + rsa-*) echo rsa-sha2-512,rsa-sha2-256,ssh-rsa;; + ed25519-sk) echo sk-ssh-ed25519@openssh.com;; + ecdsa-sk) echo sk-ecdsa-sha2-nistp256@openssh.com;; ++ sm2-256) echo sm2;; + esac + } + +diff --git a/regress/knownhosts-command.sh b/regress/knownhosts-command.sh +index 8472ec8..7f56fb1 100644 +--- a/regress/knownhosts-command.sh ++++ b/regress/knownhosts-command.sh +@@ -41,6 +41,7 @@ ${SSH} -F $OBJ/ssh_proxy x true && fail "ssh connect succeeded with bad exit" + for keytype in ${SSH_HOSTKEY_TYPES} ; do + algs=$keytype + test "x$keytype" = "xssh-dss" && continue ++ test "x$keytype" = "xsm2" && continue + test "x$keytype" = "xssh-rsa" && algs=ssh-rsa,rsa-sha2-256,rsa-sha2-512 + verbose "keytype $keytype" + cat > $OBJ/knownhosts_command << _EOF +diff --git a/regress/misc/fuzz-harness/sig_fuzz.cc b/regress/misc/fuzz-harness/sig_fuzz.cc +index b32502b..f260692 100644 +--- a/regress/misc/fuzz-harness/sig_fuzz.cc ++++ b/regress/misc/fuzz-harness/sig_fuzz.cc +@@ -30,6 +30,7 @@ int LLVMFuzzerTestOneInput(const uint8_t* sig, size_t slen) + static struct sshkey *ecdsa256 = generate_or_die(KEY_ECDSA, 256); + static struct sshkey *ecdsa384 = generate_or_die(KEY_ECDSA, 384); + static struct sshkey *ecdsa521 = generate_or_die(KEY_ECDSA, 521); ++ static struct sshkey *sm2 = generate_or_die(KEY_SM2, 256); + #endif + struct sshkey_sig_details *details = NULL; + static struct sshkey *ed25519 = generate_or_die(KEY_ED25519, 0); +@@ -53,6 +54,9 @@ int LLVMFuzzerTestOneInput(const uint8_t* sig, size_t slen) + sshkey_verify(ecdsa521, sig, slen, (const u_char *)data, dlen, NULL, 0, &details); + sshkey_sig_details_free(details); + details = NULL; ++ sshkey_verify(sm2, sig, slen, (const u_char *)data, dlen, NULL, 0, &details); ++ sshkey_sig_details_free(details); ++ details = NULL; + #endif + sshkey_verify(ed25519, sig, slen, (const u_char *)data, dlen, NULL, 0, &details); + sshkey_sig_details_free(details); +diff --git a/regress/unittests/kex/test_kex.c b/regress/unittests/kex/test_kex.c +index c26761e..d335b29 100644 +--- a/regress/unittests/kex/test_kex.c ++++ b/regress/unittests/kex/test_kex.c +@@ -151,6 +151,7 @@ do_kex_with_key(char *kex, int keytype, int bits) + #endif /* OPENSSL_HAS_ECC */ + #endif /* WITH_OPENSSL */ + server2->kex->kex[KEX_C25519_SHA256] = kex_gen_server; ++ server2->kex->kex[KEX_SM2_SM3] = kex_gen_server; + server2->kex->kex[KEX_KEM_SNTRUP761X25519_SHA512] = kex_gen_server; + server2->kex->load_host_public_key = server->kex->load_host_public_key; + server2->kex->load_host_private_key = server->kex->load_host_private_key; +@@ -185,6 +186,7 @@ do_kex(char *kex) + #endif /* OPENSSL_HAS_ECC */ + #endif /* WITH_OPENSSL */ + do_kex_with_key(kex, KEY_ED25519, 256); ++ do_kex_with_key(kex, KEY_SM2, 256); + } + + void +@@ -201,6 +203,7 @@ kex_tests(void) + do_kex("diffie-hellman-group-exchange-sha1"); + do_kex("diffie-hellman-group14-sha1"); + do_kex("diffie-hellman-group1-sha1"); ++ do_kex("sm2-sm3"); + # ifdef USE_SNTRUP761X25519 + do_kex("sntrup761x25519-sha512@openssh.com"); + # endif /* USE_SNTRUP761X25519 */ +diff --git a/ssh-ecdsa.c b/ssh-ecdsa.c +index b705157..5445ab5 100644 +--- a/ssh-ecdsa.c ++++ b/ssh-ecdsa.c +@@ -256,7 +256,8 @@ ssh_ecdsa_sign(struct sshkey *key, + *sigp = NULL; + + if (key == NULL || key->ecdsa == NULL || +- sshkey_type_plain(key->type) != KEY_ECDSA) ++ (sshkey_type_plain(key->type) != KEY_ECDSA && ++ sshkey_type_plain(key->type) != KEY_SM2)) + return SSH_ERR_INVALID_ARGUMENT; + + if ((hash_alg = sshkey_ec_nid_to_hash_alg(key->ecdsa_nid)) == -1) +@@ -332,7 +333,8 @@ ssh_ecdsa_verify(const struct sshkey *key, + unsigned char *sigb = NULL, *psig = NULL; + + if (key == NULL || key->ecdsa == NULL || +- sshkey_type_plain(key->type) != KEY_ECDSA || ++ (sshkey_type_plain(key->type) != KEY_ECDSA && ++ sshkey_type_plain(key->type) != KEY_SM2) || + sig == NULL || siglen == 0) + return SSH_ERR_INVALID_ARGUMENT; + +diff --git a/ssh-keygen.c b/ssh-keygen.c +index 0bff209..46f4998 100644 +--- a/ssh-keygen.c ++++ b/ssh-keygen.c +@@ -193,6 +193,7 @@ type_bits_valid(int type, const char *name, u_int32_t *bitsp) + *bitsp = DEFAULT_BITS_DSA; + break; + case KEY_ECDSA: ++ case KEY_SM2: + if (name != NULL && + (nid = sshkey_ecdsa_nid_from_name(name)) > 0) + *bitsp = sshkey_curve_nid_to_bits(nid); +@@ -219,6 +220,10 @@ type_bits_valid(int type, const char *name, u_int32_t *bitsp) + fatal("Invalid RSA key length: maximum is %d bits", + OPENSSL_RSA_MAX_MODULUS_BITS); + break; ++ case KEY_SM2: ++ if (*bitsp != 256) ++ fatal("Invalid SM2 key length: must be 256 bits"); ++ break; + case KEY_ECDSA: + if (sshkey_ecdsa_bits_to_nid(*bitsp) == -1) + #ifdef OPENSSL_HAS_NISTP521 +@@ -275,6 +280,9 @@ ask_filename(struct passwd *pw, const char *prompt) + case KEY_ECDSA: + name = _PATH_SSH_CLIENT_ID_ECDSA; + break; ++ case KEY_SM2: ++ name = _PATH_SSH_CLIENT_ID_SM2; ++ break; + case KEY_ECDSA_SK_CERT: + case KEY_ECDSA_SK: + name = _PATH_SSH_CLIENT_ID_ECDSA_SK; +@@ -386,6 +394,7 @@ do_convert_to_pkcs8(struct sshkey *k) + break; + #ifdef OPENSSL_HAS_ECC + case KEY_ECDSA: ++ case KEY_SM2: + if (!PEM_write_EC_PUBKEY(stdout, k->ecdsa)) + fatal("PEM_write_EC_PUBKEY failed"); + break; +@@ -410,6 +419,7 @@ do_convert_to_pem(struct sshkey *k) + break; + #ifdef OPENSSL_HAS_ECC + case KEY_ECDSA: ++ case KEY_SM2: + if (!PEM_write_EC_PUBKEY(stdout, k->ecdsa)) + fatal("PEM_write_EC_PUBKEY failed"); + break; +@@ -3280,7 +3290,7 @@ usage(void) + fprintf(stderr, + "usage: ssh-keygen [-q] [-a rounds] [-b bits] [-C comment] [-f output_keyfile]\n" + " [-m format] [-N new_passphrase] [-O option]\n" +- " [-t dsa | ecdsa | ecdsa-sk | ed25519 | ed25519-sk | rsa]\n" ++ " [-t dsa | ecdsa | ecdsa-sk | ed25519 | ed25519-sk | rsa | sm2]\n" + " [-w provider] [-Z cipher]\n" + " ssh-keygen -p [-a rounds] [-f keyfile] [-m format] [-N new_passphrase]\n" + " [-P old_passphrase] [-Z cipher]\n" +diff --git a/ssh-keyscan.c b/ssh-keyscan.c +index 245c73d..b402a21 100644 +--- a/ssh-keyscan.c ++++ b/ssh-keyscan.c +@@ -68,9 +68,10 @@ int ssh_port = SSH_DEFAULT_PORT; + #define KT_XMSS (1<<4) + #define KT_ECDSA_SK (1<<5) + #define KT_ED25519_SK (1<<6) ++#define KT_SM2 (1<<7) + + #define KT_MIN KT_DSA +-#define KT_MAX KT_ED25519_SK ++#define KT_MAX KT_SM2 + + int get_cert = 0; + int get_keytypes = KT_RSA|KT_ECDSA|KT_ED25519|KT_ECDSA_SK|KT_ED25519_SK; +@@ -267,6 +268,11 @@ keygrab_ssh2(con *c) + "ecdsa-sha2-nistp384," + "ecdsa-sha2-nistp521"; + break; ++ case KT_SM2: ++ myproposal[PROPOSAL_SERVER_HOST_KEY_ALGS] = get_cert ? ++ "sm2-cert" : ++ "sm2"; ++ break; + case KT_ECDSA_SK: + myproposal[PROPOSAL_SERVER_HOST_KEY_ALGS] = get_cert ? + "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com" : +@@ -296,6 +302,7 @@ keygrab_ssh2(con *c) + c->c_ssh->kex->kex[KEX_DH_GEX_SHA256] = kexgex_client; + # ifdef OPENSSL_HAS_ECC + c->c_ssh->kex->kex[KEX_ECDH_SHA2] = kex_gen_client; ++ c->c_ssh->kex->kex[KEX_SM2_SM3] = kex_gen_client; + # endif + #endif + c->c_ssh->kex->kex[KEX_C25519_SHA256] = kex_gen_client; +@@ -789,6 +796,9 @@ main(int argc, char **argv) + case KEY_ECDSA: + get_keytypes |= KT_ECDSA; + break; ++ case KEY_SM2: ++ get_keytypes |= KT_SM2; ++ break; + case KEY_RSA: + get_keytypes |= KT_RSA; + break; +diff --git a/ssh-sm2.c b/ssh-sm2.c +new file mode 100644 +index 0000000..75e9731 +--- /dev/null ++++ b/ssh-sm2.c +@@ -0,0 +1,381 @@ ++#include "includes.h" ++#include <sys/types.h> ++#include <openssl/bn.h> ++#include <openssl/ecdsa.h> ++#include <openssl/evp.h> ++ ++#include <string.h> ++#include "sshbuf.h" ++#include "ssherr.h" ++#include "digest.h" ++#include "sshkey.h" ++ ++#include "openbsd-compat/openssl-compat.h" ++ ++/* Reuse some ECDSA internals */ ++extern struct sshkey_impl_funcs sshkey_ecdsa_funcs; ++ ++const unsigned char *sm2_id = (const unsigned char *)"1234567812345678"; ++ ++static void ++ssh_sm2_cleanup(struct sshkey *k) ++{ ++ EC_KEY_free(k->ecdsa); ++ k->ecdsa = NULL; ++} ++ ++static int ++ssh_sm2_equal(const struct sshkey *a, const struct sshkey *b) ++{ ++ if (!sshkey_ecdsa_funcs.equal(a, b)) ++ return 0; ++ return 1; ++} ++ ++static int ++ssh_sm2_serialize_public(const struct sshkey *key, struct sshbuf *b, ++ enum sshkey_serialize_rep opts) ++{ ++ int r; ++ ++ if ((r = sshkey_ecdsa_funcs.serialize_public(key, b, opts)) != 0) ++ return r; ++ ++ return 0; ++} ++ ++static int ++ssh_sm2_deserialize_public(const char *ktype, struct sshbuf *b, ++ struct sshkey *key) ++{ ++ int r; ++ ++ if ((r = sshkey_ecdsa_funcs.deserialize_public(ktype, b, key)) != 0) ++ return r; ++ return 0; ++} ++ ++static int ++ssh_sm2_serialize_private(const struct sshkey *key, struct sshbuf *b, ++ enum sshkey_serialize_rep opts) ++{ ++ int r; ++ ++ if ((r = sshkey_ecdsa_funcs.serialize_private(key, b, opts)) != 0) ++ return r; ++ ++ return 0; ++} ++ ++static int ++ssh_sm2_deserialize_private(const char *ktype, struct sshbuf *b, ++ struct sshkey *key) ++{ ++ int r; ++ ++ if ((r = sshkey_ecdsa_funcs.deserialize_private(ktype, b, key)) != 0) ++ return r; ++ ++ return 0; ++} ++ ++static int ++ssh_sm2_generate(struct sshkey *k, int bits) ++{ ++ EC_KEY *private; ++ ++ k->ecdsa_nid = NID_sm2; ++ if ((private = EC_KEY_new_by_curve_name(k->ecdsa_nid)) == NULL) ++ return SSH_ERR_ALLOC_FAIL; ++ if (EC_KEY_generate_key(private) != 1) { ++ EC_KEY_free(private); ++ return SSH_ERR_LIBCRYPTO_ERROR; ++ } ++ EC_KEY_set_asn1_flag(private, OPENSSL_EC_NAMED_CURVE); ++ k->ecdsa = private; ++ return 0; ++} ++ ++static int ++ssh_sm2_copy_public(const struct sshkey *from, struct sshkey *to) ++{ ++ int r; ++ ++ if ((r = sshkey_ecdsa_funcs.copy_public(from, to)) != 0) ++ return r; ++ return 0; ++} ++ ++static int ++sm2_get_sig(EVP_PKEY *pkey, const u_char *data, ++ size_t datalen, u_char *sig, size_t *slen) ++{ ++ EVP_PKEY_CTX *pctx = NULL; ++ EVP_MD_CTX *mctx = NULL; ++ int ret = SSH_ERR_INTERNAL_ERROR; ++ ++ if ((pctx = EVP_PKEY_CTX_new(pkey, NULL)) == NULL) { ++ ret = SSH_ERR_ALLOC_FAIL; ++ goto out; ++ } ++ if ((mctx = EVP_MD_CTX_new()) == NULL) { ++ ret = SSH_ERR_ALLOC_FAIL; ++ goto out; ++ } ++ if (EVP_PKEY_CTX_set1_id(pctx, sm2_id, 16) != 1) { ++ ret = SSH_ERR_INTERNAL_ERROR; ++ goto out; ++ } ++ ++ EVP_MD_CTX_set_pkey_ctx(mctx, pctx); ++ ++ if ((EVP_DigestSignInit(mctx, NULL, EVP_sm3(), NULL, pkey)) != 1) { ++ ret = SSH_ERR_INTERNAL_ERROR; ++ goto out; ++ } ++ ++ if ((EVP_DigestSignUpdate(mctx, data, datalen)) != 1) { ++ ret = SSH_ERR_INTERNAL_ERROR; ++ goto out; ++ } ++ ++ if ((EVP_DigestSignFinal(mctx, sig, slen)) != 1) { ++ ret = SSH_ERR_INTERNAL_ERROR; ++ goto out; ++ } ++ ret = 0; ++ ++out: ++ EVP_PKEY_CTX_free(pctx); ++ EVP_MD_CTX_free(mctx); ++ return ret; ++} ++ ++static int ++ssh_sm2_sign(struct sshkey *key, ++ u_char **sigp, size_t *lenp, ++ const u_char *data, size_t datalen, ++ const char *alg, const char *sk_provider, const char *sk_pin, u_int compat) ++{ ++ u_char *sig = NULL; ++ size_t slen = 0; ++ int pkey_len = 0; ++ int r = 0; ++ int len = 0; ++ EVP_PKEY *key_sm2 = NULL; ++ struct sshbuf *b = NULL; ++ int ret = SSH_ERR_INTERNAL_ERROR; ++ ++ if (lenp != NULL) ++ *lenp = 0; ++ if (sigp != NULL) ++ *sigp = NULL; ++ ++ if (key == NULL || key->ecdsa == NULL || ++ sshkey_type_plain(key->type) != KEY_SM2) ++ return SSH_ERR_INVALID_ARGUMENT; ++ ++ if ((key_sm2 = EVP_PKEY_new()) == NULL) { ++ return SSH_ERR_ALLOC_FAIL; ++ } ++ ++ if ((EVP_PKEY_set1_EC_KEY(key_sm2, key->ecdsa)) != 1) { ++ ret = SSH_ERR_INTERNAL_ERROR; ++ goto out; ++ } ++ ++ if ((pkey_len = EVP_PKEY_size(key_sm2)) == 0) { ++ ret = SSH_ERR_INVALID_ARGUMENT; ++ goto out; ++ } ++ ++ slen = pkey_len; ++ ++ if ((sig = OPENSSL_malloc(pkey_len)) == NULL) { ++ ret = SSH_ERR_ALLOC_FAIL; ++ goto out; ++ } ++ ++ if (ret = sm2_get_sig(key_sm2, data, datalen, sig, &slen)) { ++ goto out; ++ } ++ ++ if ((b = sshbuf_new()) == NULL) { ++ ret = SSH_ERR_ALLOC_FAIL; ++ goto out; ++ } ++ ++ if ((r = sshbuf_put_cstring(b, "sm2")) != 0 || ++ (r = sshbuf_put_string(b, sig, slen)) != 0) ++ goto out; ++ len = sshbuf_len(b); ++ if (sigp != NULL) { ++ if ((*sigp = malloc(len)) == NULL) { ++ ret = SSH_ERR_ALLOC_FAIL; ++ goto out; ++ } ++ memcpy(*sigp, sshbuf_ptr(b), len); ++ } ++ if (lenp != NULL) ++ *lenp = len; ++ ret = 0; ++ ++out: ++ EVP_PKEY_free(key_sm2); ++ if (sig != NULL) { ++ explicit_bzero(sig, slen); ++ OPENSSL_free(sig); ++ } ++ sshbuf_free(b); ++ return ret; ++} ++ ++static int ++sm2_verify_sig(EVP_PKEY *pkey, const u_char *data, ++ size_t datalen, const u_char *sig, size_t slen) ++{ ++ EVP_PKEY_CTX *pctx = NULL; ++ EVP_MD_CTX *mctx = NULL; ++ int ret = SSH_ERR_INTERNAL_ERROR; ++ ++ if ((pctx = EVP_PKEY_CTX_new(pkey, NULL)) == NULL) { ++ ret = SSH_ERR_ALLOC_FAIL; ++ goto out; ++ } ++ ++ if ((mctx = EVP_MD_CTX_new()) == NULL) { ++ ret = SSH_ERR_ALLOC_FAIL; ++ goto out; ++ } ++ ++ if (EVP_PKEY_CTX_set1_id(pctx, sm2_id, 16) != 1) { ++ ret = SSH_ERR_INTERNAL_ERROR; ++ goto out; ++ } ++ EVP_MD_CTX_set_pkey_ctx(mctx, pctx); ++ ++ if ((EVP_DigestVerifyInit(mctx, NULL, EVP_sm3(), NULL, pkey)) != 1) { ++ ret = SSH_ERR_INTERNAL_ERROR; ++ goto out; ++ } ++ ++ if ((EVP_DigestVerifyUpdate(mctx, data, datalen)) != 1) { ++ ret = SSH_ERR_INTERNAL_ERROR; ++ goto out; ++ } ++ ++ if ((EVP_DigestVerifyFinal(mctx, sig, slen)) != 1) { ++ ret = SSH_ERR_INTERNAL_ERROR; ++ goto out; ++ } ++ ++ ret = 0; ++out: ++ EVP_PKEY_CTX_free(pctx); ++ EVP_MD_CTX_free(mctx); ++ return ret; ++} ++ ++static int ++ssh_sm2_verify(const struct sshkey *key, ++ const u_char *signature, size_t signaturelen, ++ const u_char *data, size_t datalen, const char *alg, u_int compat, ++ struct sshkey_sig_details **detailsp) ++{ ++ const u_char *sig = NULL; ++ char *ktype = NULL; ++ size_t slen = 0; ++ int pkey_len = 0; ++ int r = 0; ++ int len = 0; ++ EVP_PKEY *key_sm2 = NULL; ++ struct sshbuf *b = NULL; ++ int ret = SSH_ERR_INTERNAL_ERROR; ++ ++ if (key == NULL || ++ sshkey_type_plain(key->type) != KEY_SM2 || ++ signature == NULL || signaturelen == 0) ++ return SSH_ERR_INVALID_ARGUMENT; ++ ++ if ((b = sshbuf_from(signature, signaturelen)) == NULL) ++ return SSH_ERR_ALLOC_FAIL; ++ ++ if ((r = sshbuf_get_cstring(b, &ktype, NULL)) != 0 || ++ (r = sshbuf_get_string_direct(b, &sig, &slen)) != 0) ++ goto out; ++ ++ if (strcmp("sm2", ktype) != 0) { ++ ret = SSH_ERR_KEY_TYPE_MISMATCH; ++ goto out; ++ } ++ ++ if (sshbuf_len(b) != 0) { ++ ret = SSH_ERR_UNEXPECTED_TRAILING_DATA; ++ goto out; ++ } ++ ++ if ((key_sm2 = EVP_PKEY_new()) == NULL) { ++ ret = SSH_ERR_ALLOC_FAIL; ++ goto out; ++ } ++ ++ if ((EVP_PKEY_set1_EC_KEY(key_sm2, key->ecdsa)) != 1) { ++ ret = SSH_ERR_INTERNAL_ERROR; ++ goto out; ++ } ++ ++ if ((pkey_len = EVP_PKEY_size(key_sm2)) == 0) { ++ ret = SSH_ERR_INVALID_ARGUMENT; ++ goto out; ++ } ++ ++ if (ret = sm2_verify_sig(key_sm2, data, datalen, sig, slen)) { ++ goto out; ++ } ++ ++ ret = 0; ++out: ++ EVP_PKEY_free(key_sm2); ++ sshbuf_free(b); ++ free(ktype); ++ return ret; ++} ++ ++static const struct sshkey_impl_funcs sshkey_sm2_funcs = { ++ /* .size = */ NULL, ++ /* .alloc = */ NULL, ++ /* .cleanup = */ ssh_sm2_cleanup, ++ /* .equal = */ ssh_sm2_equal, ++ /* .ssh_serialize_public = */ ssh_sm2_serialize_public, ++ /* .ssh_deserialize_public = */ ssh_sm2_deserialize_public, ++ /* .ssh_serialize_private = */ ssh_sm2_serialize_private, ++ /* .ssh_deserialize_private = */ssh_sm2_deserialize_private, ++ /* .generate = */ ssh_sm2_generate, ++ /* .copy_public = */ ssh_sm2_copy_public, ++ /* .sign = */ ssh_sm2_sign, ++ /* .verify = */ ssh_sm2_verify, ++}; ++ ++const struct sshkey_impl sshkey_sm2_impl = { ++ /* .name = */ "sm2", ++ /* .shortname = */ "SM2", ++ /* .sigalg = */ NULL, ++ /* .type = */ KEY_SM2, ++ /* .nid = */ NID_sm2, ++ /* .cert = */ 0, ++ /* .sigonly = */ 0, ++ /* .keybits = */ 256, ++ /* .funcs = */ &sshkey_sm2_funcs, ++}; ++ ++const struct sshkey_impl sshkey_sm2_cert_impl = { ++ /* .name = */ "sm2-cert", ++ /* .shortname = */ "SM2-CERT", ++ /* .sigalg = */ NULL, ++ /* .type = */ KEY_SM2_CERT, ++ /* .nid = */ NID_sm2, ++ /* .cert = */ 1, ++ /* .sigonly = */ 0, ++ /* .keybits = */ 256, ++ /* .funcs = */ &sshkey_sm2_funcs, ++}; +diff --git a/ssh_api.c b/ssh_api.c +index d3c6617..adc2598 100644 +--- a/ssh_api.c ++++ b/ssh_api.c +@@ -115,6 +115,7 @@ ssh_init(struct ssh **sshp, int is_server, struct kex_params *kex_params) + ssh->kex->kex[KEX_DH_GEX_SHA256] = kexgex_server; + # ifdef OPENSSL_HAS_ECC + ssh->kex->kex[KEX_ECDH_SHA2] = kex_gen_server; ++ ssh->kex->kex[KEX_SM2_SM3] = kex_gen_server; + # endif + #endif /* WITH_OPENSSL */ + ssh->kex->kex[KEX_C25519_SHA256] = kex_gen_server; +@@ -133,6 +134,7 @@ ssh_init(struct ssh **sshp, int is_server, struct kex_params *kex_params) + ssh->kex->kex[KEX_DH_GEX_SHA256] = kexgex_client; + # ifdef OPENSSL_HAS_ECC + ssh->kex->kex[KEX_ECDH_SHA2] = kex_gen_client; ++ ssh->kex->kex[KEX_SM2_SM3] = kex_gen_client; + # endif + #endif /* WITH_OPENSSL */ + ssh->kex->kex[KEX_C25519_SHA256] = kex_gen_client; +diff --git a/sshconnect2.c b/sshconnect2.c +index 3acfdb6..3fbff57 100644 +--- a/sshconnect2.c ++++ b/sshconnect2.c +@@ -326,6 +326,7 @@ ssh_kex2(struct ssh *ssh, char *host, struct sockaddr *hostaddr, u_short port, + ssh->kex->kex[KEX_DH_GEX_SHA256] = kexgex_client; + # ifdef OPENSSL_HAS_ECC + ssh->kex->kex[KEX_ECDH_SHA2] = kex_gen_client; ++ ssh->kex->kex[KEX_SM2_SM3] = kex_gen_client; + # endif + # ifdef GSSAPI + if (options.gss_keyex) { +diff --git a/sshd.c b/sshd.c +index f366457..52c66ed 100644 +--- a/sshd.c ++++ b/sshd.c +@@ -695,6 +695,7 @@ list_hostkey_types(void) + /* FALLTHROUGH */ + case KEY_DSA: + case KEY_ECDSA: ++ case KEY_SM2: + case KEY_ED25519: + case KEY_ECDSA_SK: + case KEY_ED25519_SK: +@@ -716,6 +717,7 @@ list_hostkey_types(void) + /* FALLTHROUGH */ + case KEY_DSA_CERT: + case KEY_ECDSA_CERT: ++ case KEY_SM2_CERT: + case KEY_ED25519_CERT: + case KEY_ECDSA_SK_CERT: + case KEY_ED25519_SK_CERT: +@@ -742,6 +744,7 @@ get_hostkey_by_type(int type, int nid, int need_private, struct ssh *ssh) + case KEY_RSA_CERT: + case KEY_DSA_CERT: + case KEY_ECDSA_CERT: ++ case KEY_SM2_CERT: + case KEY_ED25519_CERT: + case KEY_ECDSA_SK_CERT: + case KEY_ED25519_SK_CERT: +@@ -758,8 +761,10 @@ get_hostkey_by_type(int type, int nid, int need_private, struct ssh *ssh) + continue; + switch (type) { + case KEY_ECDSA: ++ case KEY_SM2: + case KEY_ECDSA_SK: + case KEY_ECDSA_CERT: ++ case KEY_SM2_CERT: + case KEY_ECDSA_SK_CERT: + if (key->ecdsa_nid != nid) + continue; +@@ -2012,6 +2017,7 @@ main(int ac, char **av) + case KEY_RSA: + case KEY_DSA: + case KEY_ECDSA: ++ case KEY_SM2: + case KEY_ED25519: + case KEY_ECDSA_SK: + case KEY_ED25519_SK: +@@ -2573,6 +2579,7 @@ do_ssh2_kex(struct ssh *ssh) + kex->kex[KEX_DH_GEX_SHA256] = kexgex_server; + # ifdef OPENSSL_HAS_ECC + kex->kex[KEX_ECDH_SHA2] = kex_gen_server; ++ kex->kex[KEX_SM2_SM3] = kex_gen_server; + # endif + # ifdef GSSAPI + if (options.gss_keyex) { +diff --git a/sshkey.c b/sshkey.c +index 1735159..1aee244 100644 +--- a/sshkey.c ++++ b/sshkey.c +@@ -130,6 +130,8 @@ extern const struct sshkey_impl sshkey_dsa_cert_impl; + extern const struct sshkey_impl sshkey_xmss_impl; + extern const struct sshkey_impl sshkey_xmss_cert_impl; + #endif ++extern const struct sshkey_impl sshkey_sm2_impl; ++extern const struct sshkey_impl sshkey_sm2_cert_impl; + + static int ssh_gss_equal(const struct sshkey *, const struct sshkey *) + { +@@ -237,6 +239,8 @@ const struct sshkey_impl * const keyimpls[] = { + &sshkey_xmss_cert_impl, + #endif + &sshkey_gss_kex_impl, ++ &sshkey_sm2_impl, ++ &sshkey_sm2_cert_impl, + NULL + }; + +@@ -340,6 +344,8 @@ key_type_is_ecdsa_variant(int type) + case KEY_ECDSA_CERT: + case KEY_ECDSA_SK: + case KEY_ECDSA_SK_CERT: ++ case KEY_SM2: ++ case KEY_SM2_CERT: + return 1; + } + return 0; +@@ -548,6 +554,8 @@ sshkey_type_plain(int type) + return KEY_ED25519_SK; + case KEY_XMSS_CERT: + return KEY_XMSS; ++ case KEY_SM2_CERT: ++ return KEY_SM2; + default: + return type; + } +@@ -564,6 +572,8 @@ sshkey_type_certified(int type) + return KEY_DSA_CERT; + case KEY_ECDSA: + return KEY_ECDSA_CERT; ++ case KEY_SM2: ++ return KEY_SM2_CERT; + case KEY_ECDSA_SK: + return KEY_ECDSA_SK_CERT; + case KEY_ED25519: +@@ -670,6 +680,8 @@ sshkey_curve_name_to_nid(const char *name) + else if (strcmp(name, "nistp521") == 0) + return NID_secp521r1; + # endif /* OPENSSL_HAS_NISTP521 */ ++ else if (strcmp(name, "sm2") == 0) ++ return NID_sm2; + else + return -1; + } +@@ -686,6 +698,8 @@ sshkey_curve_nid_to_bits(int nid) + case NID_secp521r1: + return 521; + # endif /* OPENSSL_HAS_NISTP521 */ ++ case NID_sm2: ++ return 256; + default: + return 0; + } +@@ -720,6 +734,8 @@ sshkey_curve_nid_to_name(int nid) + case NID_secp521r1: + return "nistp521"; + # endif /* OPENSSL_HAS_NISTP521 */ ++ case NID_sm2: ++ return "sm2"; + default: + return NULL; + } +@@ -3424,6 +3440,7 @@ sshkey_private_to_blob_pem_pkcs8(struct sshkey *key, struct sshbuf *buf, + break; + #ifdef OPENSSL_HAS_ECC + case KEY_ECDSA: ++ case KEY_SM2: + if (format == SSHKEY_PRIVATE_PEM) { + success = PEM_write_bio_ECPrivateKey(bio, key->ecdsa, + cipher, passphrase, len, NULL, NULL); +@@ -3485,6 +3502,7 @@ sshkey_private_to_fileblob(struct sshkey *key, struct sshbuf *blob, + #ifdef WITH_OPENSSL + case KEY_DSA: + case KEY_ECDSA: ++ case KEY_SM2: + case KEY_RSA: + break; /* see below */ + #endif /* WITH_OPENSSL */ +@@ -3665,6 +3683,9 @@ sshkey_parse_private_pem_fileblob(struct sshbuf *blob, int type, + prv->ecdsa = EVP_PKEY_get1_EC_KEY(pk); + prv->type = KEY_ECDSA; + prv->ecdsa_nid = sshkey_ecdsa_key_to_nid(prv->ecdsa); ++ if (prv->ecdsa_nid == NID_sm2) { ++ prv->type = KEY_SM2; ++ } + if (prv->ecdsa_nid == -1 || + sshkey_curve_nid_to_name(prv->ecdsa_nid) == NULL || + sshkey_ec_validate_public(EC_KEY_get0_group(prv->ecdsa), +diff --git a/sshkey.h b/sshkey.h +index 8d662d1..c8d2662 100644 +--- a/sshkey.h ++++ b/sshkey.h +@@ -68,6 +68,8 @@ enum sshkey_types { + KEY_DSA_CERT, + KEY_ECDSA_CERT, + KEY_ED25519_CERT, ++ KEY_SM2, ++ KEY_SM2_CERT, + KEY_XMSS, + KEY_XMSS_CERT, + KEY_ECDSA_SK, +-- +2.23.0 + |