Skip to content

Commit ab0d2c9

Browse files
Refactor the aggregation check logic to also use it for deprecation consistency checking
1 parent ebe0734 commit ab0d2c9

File tree

6 files changed

+66
-49
lines changed

6 files changed

+66
-49
lines changed

pydsdl/_serializable/_array.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import typing
88
import warnings
99
from .._bit_length_set import BitLengthSet
10-
from ._serializable import SerializableType, TypeParameterError, AggregationError
10+
from ._serializable import SerializableType, TypeParameterError, AggregationFailure
1111
from ._primitive import UnsignedIntegerType, PrimitiveType
1212

1313

@@ -22,17 +22,16 @@ def __init__(self, element_type: SerializableType, capacity: int):
2222
self._capacity = int(capacity)
2323
if self._capacity < 1:
2424
raise InvalidNumberOfElementsError("Array capacity cannot be less than 1")
25-
if not self.element_type.is_valid_aggregate(self):
26-
raise AggregationError("Specified element type of %r is not valid" % str(self))
2725

2826
@property
2927
def deprecated(self) -> bool:
3028
return self.element_type.deprecated
3129

32-
def is_valid_aggregate(self, aggregate: SerializableType) -> bool:
33-
# At the moment, arrays of arrays are not allowed, but this is ensured by the grammar definition.
34-
# This restriction will be lifted eventually. All that needs to be done is to update the grammar only.
35-
return True
30+
def check_aggregation(self, aggregate: "SerializableType") -> typing.Optional[AggregationFailure]:
31+
af = self.element_type.check_aggregation(self)
32+
if af is not None:
33+
return AggregationFailure(self, aggregate, "Element type of %r is not valid: %s" % (str(self), af.message))
34+
return super().check_aggregation(aggregate)
3635

3736
@property
3837
def element_type(self) -> SerializableType:

pydsdl/_serializable/_composite.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from .. import _expression
1111
from .. import _port_id_ranges
1212
from .._bit_length_set import BitLengthSet
13-
from ._serializable import SerializableType, TypeParameterError, AggregationError
13+
from .._error import InvalidDefinitionError
14+
from ._serializable import SerializableType, TypeParameterError, AggregationFailure
1415
from ._attribute import Attribute, Field, PaddingField, Constant
1516
from ._name import check_name, InvalidNameError
1617
from ._primitive import PrimitiveType, UnsignedIntegerType
@@ -39,7 +40,7 @@ class MalformedUnionError(TypeParameterError):
3940
pass
4041

4142

42-
class DeprecatedDependencyError(AggregationError):
43+
class AggregationError(InvalidDefinitionError):
4344
pass
4445

4546

@@ -123,22 +124,13 @@ def __init__( # pylint: disable=too-many-arguments
123124
if not (0 <= port_id <= _port_id_ranges.MAX_SUBJECT_ID):
124125
raise InvalidFixedPortIDError("Fixed subject ID %r is not valid" % port_id)
125126

126-
# Consistent deprecation check.
127-
# A non-deprecated type cannot be dependent on deprecated types.
128-
# A deprecated type can be dependent on anything.
129-
if not self.deprecated:
130-
for a in self._attributes:
131-
if a.data_type.deprecated:
132-
raise DeprecatedDependencyError(
133-
"A type cannot depend on deprecated types unless it is also deprecated."
134-
)
135-
136127
# Aggregation check. For example:
137128
# - Types like utf8 and byte cannot be used outside of arrays.
138129
# - A non-deprecated type cannot depend on a deprecated type.
139130
for a in self._attributes:
140-
if not a.data_type.is_valid_aggregate(self):
141-
raise AggregationError("Type of %r is not a valid field type for %s" % (str(a), self))
131+
af = a.data_type.check_aggregation(self)
132+
if af is not None:
133+
raise AggregationError("Type of %r is not a valid field type for %s: %s" % (str(a), self, af.message))
142134

143135
@property
144136
def full_name(self) -> str:
@@ -196,12 +188,8 @@ def bit_length_set(self) -> BitLengthSet:
196188
"""
197189
raise NotImplementedError
198190

199-
def is_valid_aggregate(self, aggregate: SerializableType) -> bool:
200-
if self.deprecated: # Deprecation consistency check: non-deprecated cannot depend on a deprecated type.
201-
if isinstance(aggregate, CompositeType) and not aggregate.deprecated:
202-
return False
203-
204-
return True
191+
def check_aggregation(self, aggregate: "SerializableType") -> typing.Optional[AggregationFailure]:
192+
return super().check_aggregation(aggregate)
205193

206194
@property
207195
def deprecated(self) -> bool:
@@ -615,8 +603,11 @@ def iterate_fields_with_offsets(
615603
base_offset = base_offset + self.delimiter_header_type.bit_length_set
616604
return self.inner_type.iterate_fields_with_offsets(base_offset)
617605

618-
def is_valid_aggregate(self, aggregate: SerializableType) -> bool:
619-
return self.inner_type.is_valid_aggregate(aggregate)
606+
def check_aggregation(self, aggregate: "SerializableType") -> typing.Optional[AggregationFailure]:
607+
af = self.inner_type.check_aggregation(aggregate)
608+
if af is not None:
609+
return af
610+
return super().check_aggregation(aggregate)
620611

621612
def __repr__(self) -> str:
622613
return "%s(inner=%r, extent=%r)" % (self.__class__.__name__, self.inner_type, self.extent)

pydsdl/_serializable/_primitive.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import typing
1111
import fractions
1212
from .._bit_length_set import BitLengthSet
13-
from ._serializable import SerializableType, TypeParameterError
13+
from ._serializable import SerializableType, TypeParameterError, AggregationFailure
1414

1515

1616
ValueRange = typing.NamedTuple("ValueRange", [("min", fractions.Fraction), ("max", fractions.Fraction)])
@@ -56,8 +56,8 @@ def deprecated(self) -> bool:
5656
"""Primitive types cannot be deprecated."""
5757
return False
5858

59-
def is_valid_aggregate(self, aggregate: SerializableType) -> bool:
60-
return True
59+
def check_aggregation(self, aggregate: "SerializableType") -> typing.Optional[AggregationFailure]:
60+
return super().check_aggregation(aggregate)
6161

6262
@property
6363
def bit_length(self) -> int:
@@ -174,10 +174,12 @@ class ByteType(UnsignedIntegerType):
174174
def __init__(self) -> None:
175175
super().__init__(bit_length=PrimitiveType.BITS_IN_BYTE, cast_mode=PrimitiveType.CastMode.TRUNCATED)
176176

177-
def is_valid_aggregate(self, aggregate: SerializableType) -> bool:
177+
def check_aggregation(self, aggregate: "SerializableType") -> typing.Optional[AggregationFailure]:
178178
from ._array import ArrayType
179179

180-
return isinstance(aggregate, ArrayType)
180+
if not isinstance(aggregate, ArrayType):
181+
return AggregationFailure(self, aggregate, "The byte type can only be used as an array element type")
182+
return super().check_aggregation(aggregate)
181183

182184
def __str__(self) -> str:
183185
return "byte"
@@ -187,10 +189,14 @@ class UTF8Type(UnsignedIntegerType):
187189
def __init__(self) -> None:
188190
super().__init__(bit_length=8, cast_mode=PrimitiveType.CastMode.TRUNCATED)
189191

190-
def is_valid_aggregate(self, aggregate: SerializableType) -> bool:
192+
def check_aggregation(self, aggregate: "SerializableType") -> typing.Optional[AggregationFailure]:
191193
from ._array import VariableLengthArrayType
192194

193-
return isinstance(aggregate, VariableLengthArrayType)
195+
if not isinstance(aggregate, VariableLengthArrayType):
196+
return AggregationFailure(
197+
self, aggregate, "The utf8 type can only be used as a variable-length array element type"
198+
)
199+
return super().check_aggregation(aggregate)
194200

195201
def __str__(self) -> str:
196202
return "utf8"

pydsdl/_serializable/_serializable.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Author: Pavel Kirienko <[email protected]>
44

55
import abc
6+
from typing import Optional, NamedTuple
67
from .. import _expression
78
from .. import _error
89
from .._bit_length_set import BitLengthSet
@@ -12,8 +13,18 @@ class TypeParameterError(_error.InvalidDefinitionError):
1213
pass
1314

1415

15-
class AggregationError(_error.InvalidDefinitionError):
16-
pass
16+
AggregationFailure = NamedTuple(
17+
"AggregationFailure",
18+
[
19+
("inner", "SerializableType"),
20+
("outer", "SerializableType"),
21+
("message", str),
22+
],
23+
)
24+
"""
25+
This is returned by :meth:`SerializableType.check_aggregation()` to indicate the reason of the failure.
26+
Eventually this will need to be replaced with a dataclass (sadly no dataclasses in Python 3.6).
27+
"""
1728

1829

1930
class SerializableType(_expression.Any):
@@ -69,14 +80,20 @@ def deprecated(self) -> bool:
6980
"""
7081
raise NotImplementedError
7182

72-
@abc.abstractmethod
73-
def is_valid_aggregate(self, aggregate: "SerializableType") -> bool:
83+
def check_aggregation(self, aggregate: "SerializableType") -> Optional[AggregationFailure]:
7484
"""
75-
Returns true iff the argument is a suitable aggregate type for this element type.
85+
Returns None iff the argument is a suitable aggregate type for this element type;
86+
otherwise, returns an instance of :class:`AggregationFailure` describing the reason of the failure.
7687
This is needed to detect incorrect aggregations, such as padding fields in unions,
77-
or standalone utf8 or byte, etc.
88+
or standalone utf8 or byte, or a deprecated type nested into a non-deprecated type, etc.
7889
"""
79-
raise NotImplementedError
90+
if self.deprecated and not aggregate.deprecated:
91+
return AggregationFailure(
92+
self,
93+
aggregate,
94+
"A non-deprecated type %s cannot depend on a deprecated type %s" % (aggregate, self),
95+
)
96+
return None
8097

8198
def _attribute(self, name: _expression.String) -> _expression.Any:
8299
if name.native_value == "_bit_length_": # Experimental non-standard extension

pydsdl/_serializable/_void.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
# This software is distributed under the terms of the MIT License.
33
# Author: Pavel Kirienko <[email protected]>
44

5+
import typing
56
from .._bit_length_set import BitLengthSet
6-
from ._serializable import SerializableType
7+
from ._serializable import SerializableType, AggregationFailure
78
from ._primitive import InvalidBitLengthError
89

910

@@ -29,11 +30,12 @@ def deprecated(self) -> bool:
2930
"""Void types cannot be deprecated."""
3031
return False
3132

32-
def is_valid_aggregate(self, aggregate: SerializableType) -> bool:
33+
def check_aggregation(self, aggregate: "SerializableType") -> typing.Optional[AggregationFailure]:
3334
from ._composite import StructureType, CompositeType
3435

35-
# Unions are not allowed to contain void types. Only StructureType is allowed to do so.
36-
return isinstance(aggregate, CompositeType) and isinstance(aggregate.inner_type, StructureType)
36+
if not isinstance(aggregate, CompositeType) or not isinstance(aggregate.inner_type, StructureType):
37+
return AggregationFailure(self, aggregate, "Void types can only be aggregated into structures")
38+
return super().check_aggregation(aggregate)
3739

3840
@property
3941
def bit_length(self) -> int:

pydsdl/_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,7 +1696,7 @@ def _unittest_inconsistent_deprecation(wrkspc: Workspace) -> None:
16961696
],
16971697
)
16981698

1699-
with raises(_error.InvalidDefinitionError, match="(?i).*depend.*deprecated.*"):
1699+
with raises(_error.InvalidDefinitionError, match="(?i).*depend.*deprecated.*") as exc_info:
17001700
parse_definition(
17011701
wrkspc.parse_new(
17021702
"ns/C.1.0.dsdl",
@@ -1709,8 +1709,9 @@ def _unittest_inconsistent_deprecation(wrkspc: Workspace) -> None:
17091709
),
17101710
[wrkspc.parse_new("ns/X.1.0.dsdl", "@deprecated\n@sealed")],
17111711
)
1712+
print(exc_info.value.text)
17121713

1713-
with raises(_error.InvalidDefinitionError, match="(?i).*depend.*deprecated.*"):
1714+
with raises(_error.InvalidDefinitionError, match="(?i).*depend.*deprecated.*") as exc_info:
17141715
parse_definition(
17151716
wrkspc.parse_new(
17161717
"ns/C.1.0.dsdl",
@@ -1723,6 +1724,7 @@ def _unittest_inconsistent_deprecation(wrkspc: Workspace) -> None:
17231724
),
17241725
[wrkspc.parse_new("ns/X.1.0.dsdl", "@deprecated\n@sealed")],
17251726
)
1727+
print(exc_info.value.text)
17261728

17271729
parse_definition(
17281730
wrkspc.parse_new(

0 commit comments

Comments
 (0)