diff --git a/key/key.go b/key/key.go index 63aab05..24718c1 100644 --- a/key/key.go +++ b/key/key.go @@ -57,9 +57,29 @@ func (k *Key) ecdsaPubKey() (*ecdsa.PublicKey, error) { return nil, err } - ecdsaKey := &ecdsa.PublicKey{Curve: elliptic.P256(), - X: big.NewInt(0).SetBytes(ecc.X.Buffer), - Y: big.NewInt(0).SetBytes(ecc.Y.Buffer), + eccdeets, err := pub.Parameters.ECCDetail() + if err != nil { + return nil, err + } + + var ecdsaKey *ecdsa.PublicKey + + switch eccdeets.CurveID { + case tpm2.TPMECCNistP256: + ecdsaKey = &ecdsa.PublicKey{Curve: elliptic.P256(), + X: big.NewInt(0).SetBytes(ecc.X.Buffer), + Y: big.NewInt(0).SetBytes(ecc.Y.Buffer), + } + case tpm2.TPMECCNistP384: + ecdsaKey = &ecdsa.PublicKey{Curve: elliptic.P384(), + X: big.NewInt(0).SetBytes(ecc.X.Buffer), + Y: big.NewInt(0).SetBytes(ecc.Y.Buffer), + } + case tpm2.TPMECCNistP521: + ecdsaKey = &ecdsa.PublicKey{Curve: elliptic.P521(), + X: big.NewInt(0).SetBytes(ecc.X.Buffer), + Y: big.NewInt(0).SetBytes(ecc.Y.Buffer), + } } return ecdsaKey, nil @@ -400,15 +420,30 @@ func CreateKey(tpm transport.TPMCloser, keytype tpm2.TPMAlgID, bits int, pin, co } func ImportKey(tpm transport.TPMCloser, pk any, pin, comment []byte) (*Key, error) { - var public tpm2.TPMTPublic var sensitive tpm2.TPMTSensitive var unique tpm2.TPMUPublicID var keytype tpm2.TPMAlgID + supportedECCBitsizes := SupportedECCAlgorithms(tpm) + switch p := pk.(type) { case ecdsa.PrivateKey: + var curveid tpm2.TPMECCCurve + + if !slices.Contains(supportedECCBitsizes, p.Params().BitSize) { + return nil, fmt.Errorf("invalid ecdsa key length: TPM does not support %v bits", p.Params().BitSize) + } + + switch p.Params().BitSize { + case 256: + curveid = tpm2.TPMECCNistP256 + case 384: + curveid = tpm2.TPMECCNistP384 + case 512: + curveid = tpm2.TPMECCNistP521 + } keytype = tpm2.TPMAlgECDSA @@ -417,7 +452,7 @@ func ImportKey(tpm transport.TPMCloser, pk any, pin, comment []byte) (*Key, erro SensitiveType: tpm2.TPMAlgECC, Sensitive: tpm2.NewTPMUSensitiveComposite( tpm2.TPMAlgECC, - &tpm2.TPM2BECCParameter{Buffer: p.D.FillBytes(make([]byte, 32))}, + &tpm2.TPM2BECCParameter{Buffer: p.D.FillBytes(make([]byte, len(p.D.Bytes())))}, ), } @@ -425,10 +460,10 @@ func ImportKey(tpm transport.TPMCloser, pk any, pin, comment []byte) (*Key, erro tpm2.TPMAlgECC, &tpm2.TPMSECCPoint{ X: tpm2.TPM2BECCParameter{ - Buffer: p.X.FillBytes(make([]byte, 32)), + Buffer: p.X.FillBytes(make([]byte, len(p.X.Bytes()))), }, Y: tpm2.TPM2BECCParameter{ - Buffer: p.Y.FillBytes(make([]byte, 32)), + Buffer: p.Y.FillBytes(make([]byte, len(p.Y.Bytes()))), }, }, ) @@ -443,7 +478,7 @@ func ImportKey(tpm transport.TPMCloser, pk any, pin, comment []byte) (*Key, erro Parameters: tpm2.NewTPMUPublicParms( tpm2.TPMAlgECC, &tpm2.TPMSECCParms{ - CurveID: tpm2.TPMECCNistP256, + CurveID: curveid, Scheme: tpm2.TPMTECCScheme{ Scheme: tpm2.TPMAlgECDSA, Details: tpm2.NewTPMUAsymScheme( @@ -459,6 +494,8 @@ func ImportKey(tpm transport.TPMCloser, pk any, pin, comment []byte) (*Key, erro } case rsa.PrivateKey: + // TODO: Reject larger keys than 2048 + keytype = tpm2.TPMAlgRSA // Prepare RSA key for importing diff --git a/key/key_test.go b/key/key_test.go index a86f6ae..b9051b6 100644 --- a/key/key_test.go +++ b/key/key_test.go @@ -25,6 +25,11 @@ func TestCreateKey(t *testing.T) { alg: tpm2.TPMAlgECDSA, bits: 256, }, + { + text: "p384", + alg: tpm2.TPMAlgECDSA, + bits: 384, + }, { text: "rsa", alg: tpm2.TPMAlgRSA, @@ -150,50 +155,134 @@ func TestMarshalling(t *testing.T) { } } -func TestECDSAImportKey(t *testing.T) { - pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) +func mkRSA(t *testing.T, bits int) rsa.PrivateKey { + t.Helper() + pk, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + t.Fatalf("failed to generate rsa key: %v", err) + } + return *pk +} + +func mkECDSA(t *testing.T, a elliptic.Curve) ecdsa.PrivateKey { + t.Helper() + pk, err := ecdsa.GenerateKey(a, rand.Reader) if err != nil { t.Fatalf("failed to generate ecdsa key: %v", err) } + return *pk +} +func TestImport(t *testing.T) { tpm, err := simulator.OpenSimulator() if err != nil { t.Fatal(err) } defer tpm.Close() - k, err := ImportKey(tpm, *pk, []byte(""), []byte("")) - if err != nil { - t.Fatalf("failed key import: %v", err) - } - // Test if we can load the key - // signer/signer_test.go tests the signing of the key - _, err = LoadKey(tpm, k) - if err != nil { - t.Fatalf("failed loading key: %v", err) - } -} + for _, c := range []struct { + text string + pk any + fail bool + }{ + { + text: "p256", + pk: mkECDSA(t, elliptic.P256()), + }, + { + text: "p384", + pk: mkECDSA(t, elliptic.P384()), + }, + { + text: "p521", + pk: mkECDSA(t, elliptic.P521()), + // Simulator doesn't like P521 + fail: true, + }, + { + text: "rsa2048", + pk: mkRSA(t, 2048), + }, + } { + t.Run(c.text, func(t *testing.T) { + k, err := ImportKey(tpm, c.pk, []byte(""), []byte("")) + if err != nil && c.fail { + return + } + if err != nil { + t.Fatalf("failed key import: %v", err) + } -func TestRSAImportKey(t *testing.T) { - pk, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatalf("failed to generate rsa key: %v", err) + // Test if we can load the key + // signer/signer_test.go tests the signing of the key + handle, err := LoadKey(tpm, k) + if err != nil { + t.Fatalf("failed loading key: %v", err) + } + utils.FlushHandle(tpm, handle) + }) } +} +func TestKeyPublickey(t *testing.T) { tpm, err := simulator.OpenSimulator() if err != nil { t.Fatal(err) } defer tpm.Close() - k, err := ImportKey(tpm, *pk, []byte(""), []byte("")) - if err != nil { - t.Fatalf("failed key import: %v", err) - } - // Test if we can load the key - // signer/signer_test.go tests the signing of the key - _, err = LoadKey(tpm, k) - if err != nil { - t.Fatalf("failed loading key: %v", err) + for _, c := range []struct { + text string + pk any + bitlength int + fail bool + }{ + { + text: "p256", + pk: mkECDSA(t, elliptic.P256()), + bitlength: 256, + }, + { + text: "p384", + pk: mkECDSA(t, elliptic.P384()), + bitlength: 384, + }, + { + text: "p521", + pk: mkECDSA(t, elliptic.P521()), + // Simulator doesn't like P521 + bitlength: 521, + fail: true, + }, + { + text: "rsa2048", + pk: mkRSA(t, 2048), + bitlength: 2048, + }, + } { + t.Run(c.text, func(t *testing.T) { + k, err := ImportKey(tpm, c.pk, []byte(""), []byte("")) + if err != nil && c.fail { + return + } + if err != nil { + t.Fatalf("failed key import: %v", err) + } + + pubkey, err := k.PublicKey() + if err != nil { + t.Fatalf("failed getting public key: %v", err) + } + switch pk := pubkey.(type) { + case *ecdsa.PublicKey: + if pk.Params().BitSize != c.bitlength { + t.Fatalf("wrong import, expected %v got %v bitlength", pk.Params().BitSize, c.bitlength) + } + case *rsa.PublicKey: + if pk.N.BitLen() != c.bitlength { + t.Fatalf("wrong import, expected %v got %v bitlength", pk.N.BitLen(), c.bitlength) + } + } + }) } }