# Source code for cwt.cose_message

from typing import Any, Dict, List, Optional, TypeVar

from cbor2 import CBORTag, loads

from .cbor_processor import CBORProcessor
from .const import COSE_TAG_TO_TYPE, COSE_TYPE_TO_TAG
from .cose_key_interface import COSEKeyInterface
from .enums import COSETypes
from .signer import Signer

# Type variable bound to COSEMessage; COSEMessage.countersign() annotates both
# ``self`` and its return value with it.
Self = TypeVar("Self", bound="COSEMessage")


class COSEMessage(CBORProcessor):
    """
    The COSE message.

    Wraps a decoded COSE structure (a CBOR array) together with its COSE type,
    validates its shape on construction, and supports adding and verifying
    countersignatures.
    """

    def __init__(self, type: COSETypes, msg: List[Any]):
        """
        Constructor.

        Args:
            type (COSETypes): A type of the COSE message.
            msg (List[Any]): A COSE message as a CBOR array.
        Raises:
            ValueError: The message does not match the structure required for
                the given type.
        """
        self._validate_cose_message(msg)

        self._msg = msg
        self._type = type
        self._protected = msg[0]
        self._unprotected = msg[1]
        self._payload = msg[2]
        # NOTE: the attribute name is misspelled ("otther"); it is kept as is
        # because code outside this class may reference it.
        self._otther_fields: List[bytes] = []
        self._recipients: List[List[Any]] = []
        self._signatures: List[List[Any]] = []

        if self._type == COSETypes.ENCRYPT0:
            if len(self._msg) != 3:
                raise ValueError("Invalid COSE_Encrypt0 message.")

        elif self._type == COSETypes.ENCRYPT:
            if len(self._msg) != 4:
                raise ValueError("Invalid COSE_Encrypt message.")
            if not isinstance(self._msg[3], list):
                raise ValueError("The COSE recipients should be array.")
            for recipient in self._msg[3]:
                self._validate_cose_message(recipient)
            self._recipients = self._msg[3]

        elif self._type == COSETypes.MAC0:
            if len(self._msg) != 4:
                raise ValueError("Invalid COSE_Mac0 message.")
            if not isinstance(self._msg[3], bytes):
                raise ValueError("tag should be bytes.")
            self._otther_fields = [self._msg[3]]  # tag

        elif self._type == COSETypes.MAC:
            if len(self._msg) != 5:
                raise ValueError("Invalid COSE_Mac message.")
            if not isinstance(self._msg[3], bytes):
                raise ValueError("The tag value should be bytes.")
            self._otther_fields = [self._msg[3]]  # tag
            if not isinstance(self._msg[4], list):
                raise ValueError("The COSE recipients should be array.")
            for recipient in self._msg[4]:
                self._validate_cose_message(recipient)
            self._recipients = self._msg[4]

        elif self._type == COSETypes.SIGN1:
            if len(self._msg) != 4:
                raise ValueError("Invalid COSE_Sign1 message.")
            if not isinstance(self._msg[3], bytes):
                raise ValueError("The COSE signature should be bytes.")
            self._otther_fields = [self._msg[3]]  # signature

        elif self._type == COSETypes.SIGN:
            if len(self._msg) != 4:
                raise ValueError("Invalid COSE_Sign message.")
            if not isinstance(self._msg[3], list):
                raise ValueError("The COSE signatures should be array.")
            for signature in self._msg[3]:
                self._validate_cose_message(signature)
            self._signatures = self._msg[3]

        elif self._type == COSETypes.COUNTERSIGNATURE:
            if len(self._msg) != 3:
                raise ValueError("Invalid COSE_Countersignature.")

        elif self._type == COSETypes.SIGNATURE:
            if len(self._msg) != 3:
                raise ValueError("Invalid COSE_Signature.")

        elif self._type == COSETypes.RECIPIENT:
            if len(self._msg) != 3:
                raise ValueError("Invalid COSE_Recipient.")

        else:
            raise ValueError(f"Invalid COSETypes({type}) for COSE message.")

    @classmethod
    def loads(cls, msg: bytes) -> "COSEMessage":
        """
        Builds a COSE message from its tagged CBOR encoding.

        Args:
            msg (bytes): A CBOR-tagged COSE message.
        Returns:
            COSEMessage: The decoded COSE message.
        Raises:
            ValueError: The data is not tagged, carries an unknown tag, or is
                not a valid COSE structure.
        """
        tagged = loads(msg)
        if not isinstance(tagged, CBORTag):
            raise ValueError("Invalid COSE format.")
        cose_type = COSE_TAG_TO_TYPE.get(tagged.tag, None)
        if cose_type is None:
            raise ValueError(f"Unknown CBOR tag for COSE message: {tagged.tag}.")
        return cls(cose_type, tagged.value)

    @classmethod
    def from_cose_signature(cls, signature: List[Any]) -> "COSEMessage":
        """
        Wraps a COSE_Signature structure (a CBOR array) as a COSE message.
        """
        return cls(COSETypes.SIGNATURE, signature)

    @classmethod
    def from_cose_recipient(cls, recipient: List[Any]) -> "COSEMessage":
        """
        Wraps a COSE_Recipient structure (a CBOR array) as a COSE message.
        """
        return cls(COSETypes.RECIPIENT, recipient)

    @property
    def type(self) -> COSETypes:
        """
        The type of the COSE message.
        """
        return self._type

    @property
    def protected(self) -> Dict[int, Any]:
        """
        The protected headers as a CBOR object.
        """
        # An empty protected bucket is carried as a zero-length byte string,
        # so there is nothing to decode in that case.
        if not self._protected:
            return {}
        return self._loads(self._protected)

    @property
    def unprotected(self) -> Dict[int, Any]:
        """
        The unprotected headers as a CBOR object.
        """
        return self._unprotected

    @property
    def payload(self) -> bytes:
        """
        The payload of the COSE message.
        """
        return self._payload

    @property
    def other_fields(self) -> List[bytes]:
        """
        The list of other fields of the COSE message.
        """
        return self._otther_fields

    @property
    def signatures(self) -> List[List[Any]]:
        """
        The list of signatures of the COSE message.
        """
        return self._signatures

    @property
    def recipients(self) -> List[List[Any]]:
        """
        The list of recipients of the COSE message.
        """
        return self._recipients

    def dumps(self) -> bytes:
        """
        Serializes the COSE message structure to a byte string.

        The output is wrapped in a CBOR tag when the message type has one in
        ``COSE_TYPE_TO_TAG``; otherwise the bare array is encoded.
        """
        tag = COSE_TYPE_TO_TAG.get(self._type, -1)
        if tag > 0:
            return self._dumps(CBORTag(tag, self._msg))
        return self._dumps(self._msg)

    def countersign(
        self: Self,
        signer: Signer,
        aad: bytes = b"",
        abbreviated: bool = False,
        tagged: bool = False,
    ) -> Self:
        """
        Countersigns to the COSE message with the signer specified.

        Args:
            signer(Signer): A signer object that signs the COSE message.
            aad (bytes): The application supplied additional authenticated data.
            abbreviated(bool): The type of the countersignature (abbreviated or not).
            tagged(bool): The indicator whether the countersignature is tagged or not.
        Returns:
            COSEMessage: The COSE message (self).
        Raises:
            ValueError: Invalid arguments.
            EncodeError: Failed to countersign.
        """
        if abbreviated:
            signer.sign(self._countersign_input(aad, abbreviated=True))
            # Header label 12 holds the abbreviated countersignature (bare bytes).
            self._unprotected[12] = signer.signature
            return self

        signer.sign(self._countersign_input(aad, signer.protected))
        new_cs = [signer.protected, signer.unprotected, signer.signature]

        # Header label 11 holds either one countersignature or an array of them.
        existing = self._unprotected.get(11, None)
        if not existing:
            self._unprotected[11] = new_cs
            return self
        if isinstance(existing[0], bytes):
            # A single countersignature is present: promote it to an array.
            self._unprotected[11] = [existing]
        self._unprotected[11].append(new_cs)
        return self

    def counterverify(self, key: COSEKeyInterface, aad: bytes = b"") -> Optional[List[Any]]:
        """
        Verifies a countersignature in the COSE message with the verification key specified.

        Args:
            key(COSEKeyInterface): A COSEKey that is used to verify a signature in the COSE message.
            aad (bytes): The application supplied additional authenticated data.
        Returns:
            Optional[List[Any]]: The COSE signature verified. ``None`` when the
                abbreviated countersignature was the one verified.
        Raises:
            ValueError: Invalid arguments.
            VerifyError: Failed to verify.
        """
        err: Exception = ValueError("Countersignature not found.")

        # Try the abbreviated countersignature (label 12) first. A failure is
        # remembered but not raised yet, so full countersignatures still get a chance.
        abbreviated_sig = self._unprotected.get(12, None)
        if abbreviated_sig:
            try:
                key.verify(self._countersign_input(aad, abbreviated=True), abbreviated_sig)
                return None
            except Exception as e:
                err = e

        cs = self._unprotected.get(11, None)
        if not cs:
            raise err

        if isinstance(cs[0], bytes):
            # A single countersignature: [protected, unprotected, signature].
            kid = self._get_kid(cs)
            if key.kid and kid and key.kid != kid:
                raise ValueError("kid mismatch.")
            key.verify(self._countersign_input(aad, cs[0]), cs[2])
            return cs

        # An array of countersignatures: return the first one the key verifies.
        for sig in cs:
            kid = self._get_kid(sig)
            if key.kid and kid and key.kid != kid:
                continue
            try:
                key.verify(self._countersign_input(aad, sig[0]), sig[2])
                return sig
            except Exception as e:
                err = e
        raise err

    def _countersign_input(self, aad: bytes, sign_protected: Any = None, abbreviated: bool = False) -> bytes:
        """
        Encodes the structure that a countersignature is computed over.

        Args:
            aad (bytes): The application supplied additional authenticated data.
            sign_protected (Any): The countersigner's protected headers; ignored
                when ``abbreviated`` is True.
            abbreviated (bool): Build the abbreviated structure, which has no
                countersigner header slot.
        Returns:
            bytes: The encoded to-be-signed structure.
        """
        if abbreviated:
            structure: List[Any] = ["CounterSignature0V2", self._protected]
        else:
            structure = ["CounterSignatureV2", self._protected, sign_protected]
        structure.extend([aad, self._payload])
        # Tag/signature of the countersigned message, when it has one.
        structure.extend(self._otther_fields)
        return self._dumps(structure)

    def _validate_cose_message(self, msg: List[Any]):
        """
        Checks the leading [protected, unprotected, payload] elements of a COSE
        structure and any countersignatures (label 11) in its unprotected headers.

        Raises:
            ValueError: The structure is malformed.
        """
        # Reject non-array input explicitly so malformed data surfaces as
        # ValueError rather than a TypeError/KeyError from len() or indexing.
        if not isinstance(msg, (list, tuple)) or len(msg) < 3:
            raise ValueError("Invalid COSE message.")
        if not isinstance(msg[0], bytes):
            raise ValueError("The protected headers should be bytes.")
        if not isinstance(msg[1], dict):
            raise ValueError("The unprotected headers should be Dict[int, Any].")
        if not isinstance(msg[2], bytes):
            raise ValueError("The payload should be bytes.")

        countersignatures = msg[1].get(11, None)
        if countersignatures is None:
            return
        if not isinstance(countersignatures, list):
            raise ValueError("The countersignature should be array.")
        if len(countersignatures) == 0:
            raise ValueError("Invalid countersignature.")

        if isinstance(countersignatures[0], bytes):
            # First element is bytes: this is one countersignature, not an array of them.
            self._validate_cose_message(countersignatures)
            return
        if not isinstance(countersignatures[0], list):
            raise ValueError("Invalid countersignature.")
        for cs in countersignatures:
            self._validate_cose_message(cs)

    def _get_kid(self, sig: list) -> Optional[bytes]:
        """
        Returns the kid (header label 4) of a [protected, unprotected, signature]
        structure, preferring the unprotected headers over the protected ones.
        """
        kid = sig[1].get(4, None)
        if kid:
            return kid
        # A zero-length protected header has nothing to decode and thus no kid.
        if not sig[0]:
            return None
        return self._loads(sig[0]).get(4, None)