from typing import Any, Dict, List, Optional, TypeVar
from cbor2 import CBORTag, loads
from .cbor_processor import CBORProcessor
from .const import COSE_TAG_TO_TYPE, COSE_TYPE_TO_TAG
from .cose_key_interface import COSEKeyInterface
from .enums import COSEType
from .signer import Signer
Self = TypeVar("Self", bound="COSEMessage")
[docs]class COSEMessage(CBORProcessor):
"""
The COSE message.
"""
[docs] def __init__(self, type: COSEType, msg: List[Any]):
"""
Constructor.
Args:
type (List[Any]): A type of the COSE message.
msg (List[Any]): A COSE message as a CBOR array.
"""
self._validate_cose_message(msg)
self._msg = msg
self._type = type
self._protected = msg[0]
self._unprotected = msg[1]
self._payload = msg[2]
self._otther_fields: List[bytes] = []
self._recipients: List[List[Any]] = []
self._signatures: List[List[Any]] = []
if self._type == COSEType.ENCRYPT0:
if len(self._msg) != 3:
raise ValueError("Invalid COSE_Encrypt0 message.")
elif self._type == COSEType.ENCRYPT:
if len(self._msg) != 4:
raise ValueError("Invalid COSE_Encrypt message.")
if not isinstance(self._msg[3], list):
raise ValueError("The COSE recipients should be array.")
for recipient in self._msg[3]:
self._validate_cose_message(recipient)
self._recipients = self._msg[3]
elif self._type == COSEType.MAC0:
if len(self._msg) != 4:
raise ValueError("Invalid COSE_Mac0 message.")
if not isinstance(self._msg[3], bytes):
raise ValueError("tag should be bytes.")
self._otther_fields = [self._msg[3]] # tag
elif self._type == COSEType.MAC:
if len(self._msg) != 5:
raise ValueError("Invalid COSE_Mac message.")
if not isinstance(self._msg[3], bytes):
raise ValueError("The tag value should be bytes.")
self._otther_fields = [self._msg[3]] # tag
if not isinstance(self._msg[4], list):
raise ValueError("The COSE recipients should be array.")
for recipient in self._msg[4]:
self._validate_cose_message(recipient)
self._recipients = self._msg[4]
elif self._type == COSEType.SIGN1:
if len(self._msg) != 4:
raise ValueError("Invalid COSE_Sign1 message.")
if not isinstance(self._msg[3], bytes):
raise ValueError("The COSE signature should be bytes.")
self._otther_fields = [self._msg[3]]
elif self._type == COSEType.SIGN:
if len(self._msg) != 4:
raise ValueError("Invalid COSE_Sign message.")
if not isinstance(self._msg[3], list):
raise ValueError("The COSE signatures should be array.")
for signature in self._msg[3]:
self._validate_cose_message(signature)
self._signatures = self._msg[3]
elif self._type == COSEType.COUNTERSIGNATURE:
if len(self._msg) != 3:
raise ValueError("Invalid COSE_Countersignature.")
elif self._type == COSEType.SIGNATURE:
if len(self._msg) != 3:
raise ValueError("Invalid COSE_Signature.")
elif self._type == COSEType.RECIPIENT:
if len(self._msg) != 3:
raise ValueError("Invalid COSE_Recipient.")
else:
raise ValueError(f"Invalid COSEType({type}) for COSE message.")
return
[docs] @classmethod
def loads(cls, msg: bytes):
tagged = loads(msg)
if not isinstance(tagged, CBORTag):
raise ValueError("Invalid COSE format.")
type = COSE_TAG_TO_TYPE.get(tagged.tag, None)
if type is None:
raise ValueError(f"Unknown CBOR tag for COSE message: {tagged.tag}.")
return cls(type, tagged.value)
[docs] @classmethod
def from_cose_signature(cls, signature: List[Any]):
return cls(COSEType.SIGNATURE, signature)
[docs] @classmethod
def from_cose_recipient(cls, recipient: List[Any]):
return cls(COSEType.RECIPIENT, recipient)
@property
def type(self) -> COSEType:
"""
The identifier of the key type.
"""
return self._type
@property
def protected(self) -> Dict[int, Any]:
"""
The protected headers as a CBOR object.
"""
return self._loads(self._protected)
@property
def unprotected(self) -> Dict[int, Any]:
"""
The unprotected headers as a CBOR object.
"""
return self._unprotected
@property
def payload(self) -> bytes:
"""
The payload of the COSE message.
"""
return self._payload
@property
def other_fields(self) -> List[bytes]:
"""
The list of other fields of the COSE message.
"""
return self._otther_fields
@property
def signatures(self) -> List[List[Any]]:
"""
The list of signatures of the COSE message.
"""
return self._signatures
@property
def recipients(self) -> List[List[Any]]:
"""
The list of recipients of the COSE message.
"""
return self._recipients
[docs] def dumps(self) -> bytes:
"""
Serializes the COSE message structure to a byte string.
"""
tag = COSE_TYPE_TO_TAG.get(self._type, -1)
return self._dumps(CBORTag(tag, self._msg)) if tag > 0 else self._dumps(self._msg)
[docs] def countersign(self: Self, signer: Signer, aad: bytes = b"", abbreviated: bool = False, tagged: bool = False) -> Self:
"""
Countersigns to the COSE message with the signer specified.
Args:
signer(Signer): A signer object that signs the COSE message.
aad (bytes): The application supplied additional authenticated data.
abbreviated(bool): The type of the countersignature (abbreviated or not).
tagged(bool): The indicator whether the countersignature is tagged or not.
Returns:
COSEMessage: The COSE message (self).
Raises:
ValueError: Invalid arguments.
EncodeError: Failed to countersign.
"""
if abbreviated:
to_be_signed = ["CounterSignature0V2", self._protected, aad, self._payload]
for other_field in self._otther_fields:
to_be_signed.append(other_field)
signer.sign(self._dumps(to_be_signed))
self._unprotected[12] = signer.signature
return self
to_be_signed = ["CounterSignatureV2", self._protected, signer.protected, aad, self._payload]
for other_field in self._otther_fields:
to_be_signed.append(other_field)
signer.sign(self._dumps(to_be_signed))
cs = self._unprotected.get(11, None)
if not cs:
self._unprotected[11] = [signer.protected, signer.unprotected, signer.signature]
return self
if isinstance(cs[0], bytes):
self._unprotected[11] = [cs]
self._unprotected[11].append([signer.protected, signer.unprotected, signer.signature])
return self
[docs] def counterverify(self, key: COSEKeyInterface, aad: bytes = b"") -> Optional[List[Any]]:
"""
Verifies a countersignature in the COSE message with the verification key specified.
Args:
key(COSEKeyInterface): A COSEKey that is used to verify a signature in the COSE message.
aad (bytes): The application supplied additional authenticated data.
Returns:
Optional[List[Any]]: The COSE signature verified.
Raises:
ValueError: Invalid arguments.
VerifyError: Failed to verify.
"""
err: Exception = ValueError("Countersignature not found.")
acs = self._unprotected.get(12, None)
if acs:
to_be_signed = ["CounterSignature0V2", self._protected, aad, self._payload]
for other_field in self._otther_fields:
to_be_signed.append(other_field)
try:
key.verify(self._dumps(to_be_signed), acs)
return None
except Exception as e:
err = e
cs = self._unprotected.get(11, None)
if not cs:
raise err
to_be_signed = ["CounterSignatureV2", self._protected, b"", aad, self._payload]
for other_field in self._otther_fields:
to_be_signed.append(other_field)
if isinstance(cs[0], bytes):
kid = self._get_kid(cs)
if key.kid and kid and key.kid != kid:
raise ValueError("kid mismatch.")
to_be_signed[2] = cs[0]
key.verify(self._dumps(to_be_signed), cs[2])
return cs
for sig in cs:
kid = self._get_kid(sig)
if key.kid and kid and key.kid != kid:
continue
to_be_signed[2] = sig[0]
try:
key.verify(self._dumps(to_be_signed), sig[2])
return sig
except Exception as e:
err = e
raise err
def _validate_cose_message(self, msg: List[Any]):
if len(msg) < 3:
raise ValueError("Invalid COSE message.")
if not isinstance(msg[0], bytes):
raise ValueError("The protected headers should be bytes.")
if not isinstance(msg[1], dict):
raise ValueError("The unprotected headers should be Dict[int, Any].")
if not isinstance(msg[2], bytes):
raise ValueError("The payload should be bytes.")
countersignatures = msg[1].get(11, None)
if countersignatures is None:
return
if not isinstance(countersignatures, list):
raise ValueError("The countersignature should be array.")
if len(countersignatures) == 0:
raise ValueError("Invalid countersignature.")
if isinstance(countersignatures[0], bytes):
return self._validate_cose_message(countersignatures)
elif not isinstance(countersignatures[0], list):
raise ValueError("Invalid countersignature.")
for cs in countersignatures:
self._validate_cose_message(cs)
return
def _get_kid(self, sig: list) -> Optional[bytes]:
kid = sig[1].get(4, None)
return kid if kid else self._loads(sig[0]).get(4, None)