baseid_crypto/
jwk.rs

1//! JSON Web Key (JWK) serialization and deserialization.
2
3use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
4use baseid_core::error::CryptoError;
5use baseid_core::types::KeyType;
6use p256::elliptic_curve::sec1::ToEncodedPoint;
7use serde::{Deserialize, Serialize};
8
9use crate::key::{KeyPair, PublicKey};
10
11/// A JSON Web Key as defined in RFC 7517.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Jwk {
14    pub kty: String,
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub crv: Option<String>,
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub x: Option<String>,
19    #[serde(skip_serializing_if = "Option::is_none")]
20    pub y: Option<String>,
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub d: Option<String>,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub kid: Option<String>,
25    #[serde(rename = "use", skip_serializing_if = "Option::is_none")]
26    pub key_use: Option<String>,
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub alg: Option<String>,
29}
30
31impl Jwk {
32    /// Create a JWK from a public key.
33    pub fn from_public_key(public_key: &PublicKey) -> baseid_core::Result<Self> {
34        match public_key.key_type {
35            KeyType::Bls12381G2 => Err(CryptoError::UnsupportedAlgorithm.into()),
36            KeyType::Ed25519 => Ok(Jwk {
37                kty: "OKP".to_string(),
38                crv: Some("Ed25519".to_string()),
39                x: Some(URL_SAFE_NO_PAD.encode(&public_key.bytes)),
40                y: None,
41                d: None,
42                kid: None,
43                key_use: None,
44                alg: None,
45            }),
46            KeyType::P256 => {
47                let (x, y) = decompress_ec_point::<p256::NistP256>(&public_key.bytes, 32)?;
48                Ok(Jwk {
49                    kty: "EC".to_string(),
50                    crv: Some("P-256".to_string()),
51                    x: Some(URL_SAFE_NO_PAD.encode(&x)),
52                    y: Some(URL_SAFE_NO_PAD.encode(&y)),
53                    d: None,
54                    kid: None,
55                    key_use: None,
56                    alg: None,
57                })
58            }
59            KeyType::P384 => {
60                let (x, y) = decompress_ec_point::<p384::NistP384>(&public_key.bytes, 48)?;
61                Ok(Jwk {
62                    kty: "EC".to_string(),
63                    crv: Some("P-384".to_string()),
64                    x: Some(URL_SAFE_NO_PAD.encode(&x)),
65                    y: Some(URL_SAFE_NO_PAD.encode(&y)),
66                    d: None,
67                    kid: None,
68                    key_use: None,
69                    alg: None,
70                })
71            }
72            KeyType::Secp256k1 => {
73                let (x, y) = decompress_ec_point::<k256::Secp256k1>(&public_key.bytes, 32)?;
74                Ok(Jwk {
75                    kty: "EC".to_string(),
76                    crv: Some("secp256k1".to_string()),
77                    x: Some(URL_SAFE_NO_PAD.encode(&x)),
78                    y: Some(URL_SAFE_NO_PAD.encode(&y)),
79                    d: None,
80                    kid: None,
81                    key_use: None,
82                    alg: None,
83                })
84            }
85        }
86    }
87
88    /// Convert this JWK to a `PublicKey`.
89    pub fn to_public_key(&self) -> baseid_core::Result<PublicKey> {
90        let crv = self.crv.as_deref().ok_or(CryptoError::InvalidKeyMaterial)?;
91        let x_b64 = self.x.as_deref().ok_or(CryptoError::InvalidKeyMaterial)?;
92        let x = URL_SAFE_NO_PAD
93            .decode(x_b64)
94            .map_err(|_| CryptoError::InvalidKeyMaterial)?;
95
96        match (self.kty.as_str(), crv) {
97            ("OKP", "Ed25519") => PublicKey::from_bytes(KeyType::Ed25519, &x),
98            ("EC", "P-256") => {
99                let y = decode_y(&self.y)?;
100                let bytes = encode_ec_uncompressed(&x, &y);
101                // Parse as uncompressed then re-encode as compressed
102                let point = p256::EncodedPoint::from_bytes(&bytes)
103                    .map_err(|_| CryptoError::InvalidKeyMaterial)?;
104                use p256::elliptic_curve::sec1::FromEncodedPoint;
105                let pk = p256::PublicKey::from_encoded_point(&point);
106                if pk.is_none().into() {
107                    return Err(CryptoError::InvalidKeyMaterial.into());
108                }
109                let pk = pk.unwrap();
110                let compressed = pk.to_encoded_point(true);
111                PublicKey::from_bytes(KeyType::P256, compressed.as_bytes())
112            }
113            ("EC", "P-384") => {
114                let y = decode_y(&self.y)?;
115                let bytes = encode_ec_uncompressed(&x, &y);
116                let point = p384::EncodedPoint::from_bytes(&bytes)
117                    .map_err(|_| CryptoError::InvalidKeyMaterial)?;
118                use p384::elliptic_curve::sec1::FromEncodedPoint;
119                let pk = p384::PublicKey::from_encoded_point(&point);
120                if pk.is_none().into() {
121                    return Err(CryptoError::InvalidKeyMaterial.into());
122                }
123                let pk = pk.unwrap();
124                let compressed = pk.to_encoded_point(true);
125                PublicKey::from_bytes(KeyType::P384, compressed.as_bytes())
126            }
127            ("EC", "secp256k1") => {
128                let y = decode_y(&self.y)?;
129                let bytes = encode_ec_uncompressed(&x, &y);
130                let point = k256::EncodedPoint::from_bytes(&bytes)
131                    .map_err(|_| CryptoError::InvalidKeyMaterial)?;
132                use k256::elliptic_curve::sec1::FromEncodedPoint;
133                let pk = k256::PublicKey::from_encoded_point(&point);
134                if pk.is_none().into() {
135                    return Err(CryptoError::InvalidKeyMaterial.into());
136                }
137                let pk = pk.unwrap();
138                let compressed = pk.to_encoded_point(true);
139                PublicKey::from_bytes(KeyType::Secp256k1, compressed.as_bytes())
140            }
141            _ => Err(CryptoError::UnsupportedAlgorithm.into()),
142        }
143    }
144
145    /// Create a JWK from a key pair (includes the private key as `d`).
146    pub fn from_key_pair(key_pair: &KeyPair) -> baseid_core::Result<Self> {
147        let mut jwk = Self::from_public_key(&key_pair.public)?;
148        jwk.d = Some(URL_SAFE_NO_PAD.encode(key_pair.secret_bytes()));
149        Ok(jwk)
150    }
151
152    /// Convert this JWK to a `KeyPair`. Requires the `d` field to be present.
153    pub fn to_key_pair(&self) -> baseid_core::Result<KeyPair> {
154        let d_b64 = self.d.as_deref().ok_or(CryptoError::InvalidKeyMaterial)?;
155        let secret_bytes = URL_SAFE_NO_PAD
156            .decode(d_b64)
157            .map_err(|_| CryptoError::InvalidKeyMaterial)?;
158
159        let crv = self.crv.as_deref().ok_or(CryptoError::InvalidKeyMaterial)?;
160        let key_type = match (self.kty.as_str(), crv) {
161            ("OKP", "Ed25519") => KeyType::Ed25519,
162            ("EC", "P-256") => KeyType::P256,
163            ("EC", "P-384") => KeyType::P384,
164            ("EC", "secp256k1") => KeyType::Secp256k1,
165            _ => return Err(CryptoError::UnsupportedAlgorithm.into()),
166        };
167
168        KeyPair::from_bytes(key_type, &secret_bytes)
169    }
170}
171
172/// Decompress an EC point and return (x, y) coordinate byte vectors.
173fn decompress_ec_point<C>(
174    compressed: &[u8],
175    coord_len: usize,
176) -> baseid_core::Result<(Vec<u8>, Vec<u8>)>
177where
178    C: p256::elliptic_curve::Curve + p256::elliptic_curve::CurveArithmetic,
179    <C as p256::elliptic_curve::CurveArithmetic>::AffinePoint:
180        p256::elliptic_curve::sec1::FromEncodedPoint<C>
181            + p256::elliptic_curve::sec1::ToEncodedPoint<C>,
182    p256::elliptic_curve::FieldBytesSize<C>: p256::elliptic_curve::sec1::ModulusSize,
183{
184    let point = p256::elliptic_curve::sec1::EncodedPoint::<C>::from_bytes(compressed)
185        .map_err(|_| CryptoError::InvalidKeyMaterial)?;
186
187    // If compressed, decompress via the curve's AffinePoint
188    let uncompressed = if point.is_compressed() {
189        use p256::elliptic_curve::sec1::FromEncodedPoint;
190        let affine =
191            <C as p256::elliptic_curve::CurveArithmetic>::AffinePoint::from_encoded_point(&point);
192        if affine.is_none().into() {
193            return Err(CryptoError::InvalidKeyMaterial.into());
194        }
195        use p256::elliptic_curve::sec1::ToEncodedPoint;
196        affine.unwrap().to_encoded_point(false)
197    } else {
198        point
199    };
200
201    let x = uncompressed.x().ok_or(CryptoError::InvalidKeyMaterial)?;
202    let y = uncompressed.y().ok_or(CryptoError::InvalidKeyMaterial)?;
203
204    // Pad coordinates to coord_len
205    let mut x_vec: Vec<u8> = x.iter().copied().collect();
206    let mut y_vec: Vec<u8> = y.iter().copied().collect();
207    while x_vec.len() < coord_len {
208        x_vec.insert(0, 0);
209    }
210    while y_vec.len() < coord_len {
211        y_vec.insert(0, 0);
212    }
213
214    Ok((x_vec, y_vec))
215}
216
217/// Decode the `y` coordinate from a JWK.
218fn decode_y(y: &Option<String>) -> baseid_core::Result<Vec<u8>> {
219    let y_b64 = y.as_deref().ok_or(CryptoError::InvalidKeyMaterial)?;
220    URL_SAFE_NO_PAD
221        .decode(y_b64)
222        .map_err(|_| CryptoError::InvalidKeyMaterial.into())
223}
224
/// Build the uncompressed SEC1 encoding of a point: the tag byte `0x04`
/// followed by the bytes of `x` and then the bytes of `y`.
fn encode_ec_uncompressed(x: &[u8], y: &[u8]) -> Vec<u8> {
    const UNCOMPRESSED_TAG: [u8; 1] = [0x04];
    [&UNCOMPRESSED_TAG[..], x, y].concat()
}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::key::KeyPair;

    /// Generate a key pair of `key_type` together with its public-only JWK.
    fn generated_public_jwk(key_type: KeyType) -> (KeyPair, Jwk) {
        let pair = KeyPair::generate(key_type).unwrap();
        let jwk = Jwk::from_public_key(&pair.public).unwrap();
        (pair, jwk)
    }

    /// Check public key -> JWK -> public key for `key_type` and return the JWK.
    fn check_public_roundtrip(key_type: KeyType, kty: &str, crv: &str) -> Jwk {
        let (pair, jwk) = generated_public_jwk(key_type);
        assert_eq!(jwk.kty, kty);
        assert_eq!(jwk.crv.as_deref(), Some(crv));
        let decoded = jwk.to_public_key().unwrap();
        assert_eq!(decoded.bytes, pair.public.bytes);
        jwk
    }

    /// Check key pair -> JWK -> key pair for `key_type` and return the JWK.
    fn check_key_pair_roundtrip(key_type: KeyType) -> Jwk {
        let pair = KeyPair::generate(key_type).unwrap();
        let jwk = Jwk::from_key_pair(&pair).unwrap();
        let decoded = jwk.to_key_pair().unwrap();
        assert_eq!(decoded.public.bytes, pair.public.bytes);
        assert_eq!(decoded.secret_bytes(), pair.secret_bytes());
        jwk
    }

    /// Hand-build a JWK from literal members; everything not listed is `None`.
    fn literal_jwk(kty: &str, crv: Option<&str>, x: Option<&str>, y: Option<&str>) -> Jwk {
        Jwk {
            kty: String::from(kty),
            crv: crv.map(String::from),
            x: x.map(String::from),
            y: y.map(String::from),
            d: None,
            kid: None,
            key_use: None,
            alg: None,
        }
    }

    #[test]
    fn public_key_roundtrip_ed25519() {
        let jwk = check_public_roundtrip(KeyType::Ed25519, "OKP", "Ed25519");
        assert!(jwk.d.is_none());
    }

    #[test]
    fn public_key_roundtrip_p256() {
        check_public_roundtrip(KeyType::P256, "EC", "P-256");
    }

    #[test]
    fn public_key_roundtrip_p384() {
        check_public_roundtrip(KeyType::P384, "EC", "P-384");
    }

    #[test]
    fn public_key_roundtrip_secp256k1() {
        check_public_roundtrip(KeyType::Secp256k1, "EC", "secp256k1");
    }

    #[test]
    fn key_pair_roundtrip_ed25519() {
        let jwk = check_key_pair_roundtrip(KeyType::Ed25519);
        assert!(jwk.d.is_some());
    }

    #[test]
    fn key_pair_roundtrip_p256() {
        check_key_pair_roundtrip(KeyType::P256);
    }

    #[test]
    fn key_pair_roundtrip_secp256k1() {
        check_key_pair_roundtrip(KeyType::Secp256k1);
    }

    #[test]
    fn to_key_pair_requires_d_field() {
        let (_, jwk) = generated_public_jwk(KeyType::Ed25519);
        assert!(jwk.to_key_pair().is_err());
    }

    #[test]
    fn jwk_json_serialization() {
        let (_, jwk) = generated_public_jwk(KeyType::Ed25519);
        let json = serde_json::to_string(&jwk).unwrap();
        assert!(json.contains("\"kty\":\"OKP\""));
        assert!(json.contains("\"crv\":\"Ed25519\""));
        // A public-only JWK must not serialize a private key member.
        assert!(!json.contains("\"d\""));
    }

    #[test]
    fn unsupported_curve_rejected() {
        let jwk = literal_jwk("EC", Some("P-521"), Some("AAAA"), Some("BBBB"));
        assert!(jwk.to_public_key().is_err());
    }

    #[test]
    fn key_pair_roundtrip_p384() {
        check_key_pair_roundtrip(KeyType::P384);
    }

    #[test]
    fn missing_crv_rejected() {
        let jwk = literal_jwk("OKP", None, Some("AAAA"), None);
        assert!(jwk.to_public_key().is_err());
    }

    #[test]
    fn missing_x_rejected() {
        let jwk = literal_jwk("OKP", Some("Ed25519"), None, None);
        assert!(jwk.to_public_key().is_err());
    }

    #[test]
    fn ec_missing_y_rejected() {
        let (_, mut jwk) = generated_public_jwk(KeyType::P256);
        // Drop the y coordinate so the EC key is incomplete.
        jwk.y = None;
        assert!(jwk.to_public_key().is_err());
    }

    #[test]
    fn invalid_base64_x_rejected() {
        let jwk = literal_jwk("OKP", Some("Ed25519"), Some("!!!invalid-base64!!!"), None);
        assert!(jwk.to_public_key().is_err());
    }

    #[test]
    fn unknown_kty_rejected() {
        let jwk = literal_jwk("RSA", Some("Ed25519"), Some("AAAA"), None);
        assert!(jwk.to_public_key().is_err());
    }
}