diff --git a/sasdata/quantities/_units_base.py b/sasdata/quantities/_units_base.py index 5543f1d4e..a18657782 100644 --- a/sasdata/quantities/_units_base.py +++ b/sasdata/quantities/_units_base.py @@ -1,10 +1,9 @@ -from collections.abc import Sequence -from dataclasses import dataclass +import re from fractions import Fraction from typing import Self import numpy as np -from unicode_superscript import int_as_unicode_superscript +from unicode_superscript import int_as_unicode_superscript # type: ignore[import-untyped] class DimensionError(Exception): @@ -111,15 +110,15 @@ def __pow__(self, power: int | float): (self.moles_hint * numerator) // denominator, (self.angle_hint * numerator) // denominator) - def __eq__(self: Self, other: Self): + def __eq__(self: Self, other: object) -> bool: if isinstance(other, Dimensions): - return (self.length == other.length and - self.time == other.time and - self.mass == other.mass and - self.current == other.current and - self.temperature == other.temperature and - self.moles_hint == other.moles_hint and - self.angle_hint == other.angle_hint) + return (self.length == other.length + and self.time == other.time + and self.mass == other.mass + and self.current == other.current + and self.temperature == other.temperature + and self.moles_hint == other.moles_hint + and self.angle_hint == other.angle_hint) return NotImplemented @@ -210,9 +209,6 @@ def __init__(self, self.scale = si_scaling_factor self.dimensions = dimensions - def _components(self, tokens: Sequence["UnitToken"]): - pass - def __mul__(self: Self, other: "Unit"): if isinstance(other, Unit): return Unit(self.scale * other.scale, self.dimensions * other.dimensions) @@ -246,17 +242,15 @@ def __pow__(self, power: int | float): def equivalent(self: Self, other: "Unit"): return self.dimensions == other.dimensions - def __eq__(self: Self, other: "Unit"): - return self.equivalent(other) and np.abs(np.log(self.scale/other.scale)) < 1e-5 + def __eq__(self: Self, other: object) -> bool: + if isinstance(other, Unit): + return self.equivalent(other) and np.abs(np.log(self.scale/other.scale)) < 1e-5 + return False def si_equivalent(self): """ Get the SI unit corresponding to this unit""" return Unit(1, self.dimensions) - def _format_unit(self, format_process: list["UnitFormatProcessor"]): - for processor in format_process: - pass - def __repr__(self): if self.scale == 1: # We're in SI @@ -265,9 +259,6 @@ def __repr__(self): else: return f"Unit[{self.scale}, {self.dimensions}]" - @staticmethod - def parse(unit_string: str) -> "Unit": - pass class NamedUnit(Unit): """ Units, but they have a name, and a symbol @@ -308,57 +299,204 @@ def __eq__(self, other): case _: return False - def startswith(self, prefix: str) -> bool: """Check if any representation of the unit begins with the prefix string""" prefix = prefix.lower() return (self.name is not None and self.name.lower().startswith(prefix)) \ - or (self.ascii_symbol is not None and self.ascii_symbol.lower().startswith(prefix)) \ - or (self.symbol is not None and self.symbol.lower().startswith(prefix)) - -# -# Parsing plan: -# Require unknown amounts of units to be explicitly positive or negative? -# -# - - - -@dataclass -class ProcessedUnitToken: - """ Mid processing representation of formatted units """ - base_string: str - exponent_string: str - latex_exponent_string: str - exponent: int - -class UnitFormatProcessor: - """ Represents a step in the unit processing pipeline""" - def apply(self, scale, dimensions) -> tuple[ProcessedUnitToken, float, Dimensions]: - """ This will be called to deal with each processing stage""" - -class RequiredUnitFormatProcessor(UnitFormatProcessor): - """ This unit is required to exist in the formatting """ - def __init__(self, unit: Unit, power: int = 1): - self.unit = unit - self.power = power - def apply(self, scale, dimensions) -> tuple[float, Dimensions, ProcessedUnitToken]: - new_scale = scale / (self.unit.scale * self.power) - new_dimensions = self.unit.dimensions / (dimensions**self.power) - token = ProcessedUnitToken(self.unit, self.power) - - return new_scale, new_dimensions, token -class GreedyAbsDimensionUnitFormatProcessor(UnitFormatProcessor): - """ This processor minimises the dimensionality of the unit by multiplying by as many - units of the specified type as needed """ - def __init__(self, unit: Unit): - self.unit = unit - - def apply(self, scale, dimensions) -> tuple[ProcessedUnitToken, float, Dimensions]: - pass - -class GreedyAbsDimensionUnitFormatProcessor(UnitFormatProcessor): - pass + or (self.ascii_symbol is not None and self.ascii_symbol.lower().startswith(prefix)) \ + or (self.symbol is not None and self.symbol.lower().startswith(prefix)) + + +class UnknownUnit(NamedUnit): + """A unit for an unknown quantity + + While this library attempts to handle all known SI units, it is + likely that users will want to express quantities of arbitrary + units (for example, calculating donuts per person for a meeting). + The arbitrary unit allows for these unforseeable quantities.""" + + def __init__(self, + numerator: str | list[str] | dict[str, int | float], + denominator: None | list[str] | dict[str, int | float] = None): + if numerator is None: + return TypeError + self._numerator = UnknownUnit._parse_arg(numerator) + self._denominator = UnknownUnit._parse_arg(denominator) + self._unit = NamedUnit(1, Dimensions(), "") # Unitless + + super().__init__(si_scaling_factor=1, dimensions=self._unit.dimensions, symbol=self._name()) + + @staticmethod + def _parse_arg(arg: str | list[str] | dict[str, int | float] | None) -> dict[str, int | float]: + """Parse the different possibilities for constructor arguments + + Both the numerator and the denominator could be a string, a + list of strings, or a dict. Parse any of these values into a + dictionary of names and powers. + + """ + match arg: + case None: + return {} + case str(): + return {UnknownUnit._valid_name(arg): 1} + case list(): + result: dict[str, int | float] = {} + for key in arg: + if key in result: + result[key] += 1 + else: + UnknownUnit._valid_name(key) + result[key] = 1 + return result + case dict(): + for key in arg: + UnknownUnit._valid_name(key) + return arg + case _: + raise TypeError + + @staticmethod + def _valid_name(name: str) -> str: + """Confirms that the name of a unit is appropriate + + This mostly confirms that the unit does not contain math + operators that would act on other units, like / or ^ + """ + + if re.search(r"[*/^\s]", name): + raise RuntimeError(f'Unit name "{name}" contains invalid characters (*, /, ^, or whitespace)') + + return name + + def _name(self): + num = [] + for key, value in self._numerator.items(): + if value == 1: + num.append(key) + else: + num.append(f"{key}^{value}") + den = [] + for key, value in self._denominator.items(): + den.append(f"{key}^{-value}") + num.sort() + den.sort() + return " ".join(num + den) + + def __eq__(self, other): + match other: + case UnknownUnit(): + return self._numerator == other._numerator and self._denominator == other._denominator and self._unit == other._unit + case Unit(): + return not self._numerator and not self._denominator and self._unit == other + case _: + return False + + def __mul__(self: Self, other: "Unit"): + match other: + case UnknownUnit(): + num = dict(self._numerator) + for key in other._numerator: + if key in num: + num[key] += other._numerator[key] + else: + num[key] = other._numerator[key] + den = dict(self._denominator) + for key in other._denominator: + if key in den: + den[key] += other._denominator[key] + else: + den[key] = other._denominator[key] + result = UnknownUnit(num, den) + result._unit *= other._unit + return result._reduce() + case NamedUnit() | Unit() | int() | float(): + result = UnknownUnit(self._numerator, self._denominator) + result._unit *= other + return result + case _: + return NotImplemented + + def __rmul__(self: Self, other): + return self * other + + def __truediv__(self: Self, other: "Unit") -> "UnknownUnit": + match other: + case UnknownUnit(): + num = dict(self._numerator) + for key in other._denominator: + if key in num: + num[key] += other._denominator[key] + else: + num[key] = other._denominator[key] + den = dict(self._denominator) + for key in other._numerator: + if key in den: + den[key] += other._numerator[key] + else: + den[key] = other._numerator[key] + result = UnknownUnit(num, den) + result._unit /= other._unit + return result._reduce() + case NamedUnit() | Unit() | int() | float(): + result = UnknownUnit(self._numerator, self._denominator) + result._unit /= other + return result + case _: + return NotImplemented + + def __rtruediv__(self: Self, other: "Unit") -> "UnknownUnit": + return (self/other) ** -1 + + def __pow__(self, power: int | float) -> "UnknownUnit": + match power: + case int() | float(): + num = {key: value * power for key, value in self._numerator.items()} + den = {key: value * power for key, value in self._denominator.items()} + if power < 0: + num, den = den, num + num = {k: -v for k,v in num.items()} + den = {k: -v for k,v in den.items()} + + result = UnknownUnit(num, den) + result._unit = self._unit ** power + return result + case _: + return NotImplemented + + def equivalent(self: Self, other: "Unit"): + match other: + case UnknownUnit(): + return self._unit.equivalent(other._unit) and sorted(self._numerator) == sorted(other._numerator) and sorted(self._denominator) == sorted(other._denominator) + case _: + return False + + def _reduce(self): + """Remove redundant units""" + for k in self._denominator: + if k in self._numerator: + common = min(self._numerator[k], self._denominator[k]) + self._numerator[k] -= common + self._denominator[k] -= common + dead_nums = [k for k in self._numerator if self._numerator[k] == 0] + for k in dead_nums: + del self._numerator[k] + dead_dens = [k for k in self._denominator if self._denominator[k] == 0] + for k in dead_dens: + del self._denominator[k] + return self + + def __str__(self): + result = self._name() + if type(self._unit) is NamedUnit and self._unit.name.strip(): + result += f" {self._unit.name.strip()}" + if type(self._unit) is Unit and str(self._unit).strip(): + result += f" {str(self._unit).strip()}" + return result + + def __repr__(self): + return str(self) + class UnitGroup: """ A group of units that all have the same dimensionality """ diff --git a/sasdata/quantities/accessors.py b/sasdata/quantities/accessors.py index d23268544..159d33d49 100644 --- a/sasdata/quantities/accessors.py +++ b/sasdata/quantities/accessors.py @@ -9479,6 +9479,14 @@ def radians(self) -> T: else: return quantity.in_units_of(units.radians) + @property + def rotations(self) -> T: + quantity = self.quantity + if quantity is None: + return None + else: + return quantity.in_units_of(units.rotations) + class SolidangleAccessor[T](QuantityAccessor[T]): diff --git a/sasdata/quantities/units.py b/sasdata/quantities/units.py index fe840ab85..c4ca8f12f 100644 --- a/sasdata/quantities/units.py +++ b/sasdata/quantities/units.py @@ -82,14 +82,13 @@ # Included from _units_base.py # -from collections.abc import Sequence -from dataclasses import dataclass +import re from fractions import Fraction from typing import Self import numpy as np -from sasdata.quantities.unicode_superscript import int_as_unicode_superscript +from sasdata.quantities.unicode_superscript import int_as_unicode_superscript # type: ignore[import-untyped] class DimensionError(Exception): @@ -196,15 +195,15 @@ def __pow__(self, power: int | float): (self.moles_hint * numerator) // denominator, (self.angle_hint * numerator) // denominator) - def __eq__(self: Self, other: Self): + def __eq__(self: Self, other: object) -> bool: if isinstance(other, Dimensions): - return (self.length == other.length and - self.time == other.time and - self.mass == other.mass and - self.current == other.current and - self.temperature == other.temperature and - self.moles_hint == other.moles_hint and - self.angle_hint == other.angle_hint) + return (self.length == other.length + and self.time == other.time + and self.mass == other.mass + and self.current == other.current + and self.temperature == other.temperature + and self.moles_hint == other.moles_hint + and self.angle_hint == other.angle_hint) return NotImplemented @@ -295,9 +294,6 @@ def __init__(self, self.scale = si_scaling_factor self.dimensions = dimensions - def _components(self, tokens: Sequence["UnitToken"]): - pass - def __mul__(self: Self, other: "Unit"): if isinstance(other, Unit): return Unit(self.scale * other.scale, self.dimensions * other.dimensions) @@ -331,17 +327,15 @@ def __pow__(self, power: int | float): def equivalent(self: Self, other: "Unit"): return self.dimensions == other.dimensions - def __eq__(self: Self, other: "Unit"): - return self.equivalent(other) and np.abs(np.log(self.scale/other.scale)) < 1e-5 + def __eq__(self: Self, other: object) -> bool: + if isinstance(other, Unit): + return self.equivalent(other) and np.abs(np.log(self.scale/other.scale)) < 1e-5 + return False def si_equivalent(self): """ Get the SI unit corresponding to this unit""" return Unit(1, self.dimensions) - def _format_unit(self, format_process: list["UnitFormatProcessor"]): - for processor in format_process: - pass - def __repr__(self): if self.scale == 1: # We're in SI @@ -350,9 +344,6 @@ def __repr__(self): else: return f"Unit[{self.scale}, {self.dimensions}]" - @staticmethod - def parse(unit_string: str) -> "Unit": - pass class NamedUnit(Unit): """ Units, but they have a name, and a symbol @@ -393,57 +384,204 @@ def __eq__(self, other): case _: return False - def startswith(self, prefix: str) -> bool: """Check if any representation of the unit begins with the prefix string""" prefix = prefix.lower() return (self.name is not None and self.name.lower().startswith(prefix)) \ - or (self.ascii_symbol is not None and self.ascii_symbol.lower().startswith(prefix)) \ - or (self.symbol is not None and self.symbol.lower().startswith(prefix)) + or (self.ascii_symbol is not None and self.ascii_symbol.lower().startswith(prefix)) \ + or (self.symbol is not None and self.symbol.lower().startswith(prefix)) -# -# Parsing plan: -# Require unknown amounts of units to be explicitly positive or negative? -# -# +class UnknownUnit(NamedUnit): + """A unit for an unknown quantity + While this library attempts to handle all known SI units, it is + likely that users will want to express quantities of arbitrary + units (for example, calculating donuts per person for a meeting). + The arbitrary unit allows for these unforseeable quantities.""" + + def __init__(self, + numerator: str | list[str] | dict[str, int | float], + denominator: None | list[str] | dict[str, int | float] = None): + if numerator is None: + return TypeError + self._numerator = UnknownUnit._parse_arg(numerator) + self._denominator = UnknownUnit._parse_arg(denominator) + self._unit = NamedUnit(1, Dimensions(), "") # Unitless + + super().__init__(si_scaling_factor=1, dimensions=self._unit.dimensions, symbol=self._name()) + + @staticmethod + def _parse_arg(arg: str | list[str] | dict[str, int | float] | None) -> dict[str, int | float]: + """Parse the different possibilities for constructor arguments + + Both the numerator and the denominator could be a string, a + list of strings, or a dict. Parse any of these values into a + dictionary of names and powers. + + """ + match arg: + case None: + return {} + case str(): + return {UnknownUnit._valid_name(arg): 1} + case list(): + result: dict[str, int | float] = {} + for key in arg: + if key in result: + result[key] += 1 + else: + UnknownUnit._valid_name(key) + result[key] = 1 + return result + case dict(): + for key in arg: + UnknownUnit._valid_name(key) + return arg + case _: + raise TypeError + + @staticmethod + def _valid_name(name: str) -> str: + """Confirms that the name of a unit is appropriate + + This mostly confirms that the unit does not contain math + operators that would act on other units, like / or ^ + """ + + if re.search(r"[*/^\s]", name): + raise RuntimeError(f'Unit name "{name}" contains invalid characters (*, /, ^, or whitespace)') + + return name + + def _name(self): + num = [] + for key, value in self._numerator.items(): + if value == 1: + num.append(key) + else: + num.append(f"{key}^{value}") + den = [] + for key, value in self._denominator.items(): + den.append(f"{key}^{-value}") + num.sort() + den.sort() + return " ".join(num + den) + + def __eq__(self, other): + match other: + case UnknownUnit(): + return self._numerator == other._numerator and self._denominator == other._denominator and self._unit == other._unit + case Unit(): + return not self._numerator and not self._denominator and self._unit == other + case _: + return False + + def __mul__(self: Self, other: "Unit"): + match other: + case UnknownUnit(): + num = dict(self._numerator) + for key in other._numerator: + if key in num: + num[key] += other._numerator[key] + else: + num[key] = other._numerator[key] + den = dict(self._denominator) + for key in other._denominator: + if key in den: + den[key] += other._denominator[key] + else: + den[key] = other._denominator[key] + result = UnknownUnit(num, den) + result._unit *= other._unit + return result._reduce() + case NamedUnit() | Unit() | int() | float(): + result = UnknownUnit(self._numerator, self._denominator) + result._unit *= other + return result + case _: + return NotImplemented + + def __rmul__(self: Self, other): + return self * other + + def __truediv__(self: Self, other: "Unit") -> "UnknownUnit": + match other: + case UnknownUnit(): + num = dict(self._numerator) + for key in other._denominator: + if key in num: + num[key] += other._denominator[key] + else: + num[key] = other._denominator[key] + den = dict(self._denominator) + for key in other._numerator: + if key in den: + den[key] += other._numerator[key] + else: + den[key] = other._numerator[key] + result = UnknownUnit(num, den) + result._unit /= other._unit + return result._reduce() + case NamedUnit() | Unit() | int() | float(): + result = UnknownUnit(self._numerator, self._denominator) + result._unit /= other + return result + case _: + return NotImplemented + + def __rtruediv__(self: Self, other: "Unit") -> "UnknownUnit": + return (self/other) ** -1 + + def __pow__(self, power: int | float) -> "UnknownUnit": + match power: + case int() | float(): + num = {key: value * power for key, value in self._numerator.items()} + den = {key: value * power for key, value in self._denominator.items()} + if power < 0: + num, den = den, num + num = {k: -v for k,v in num.items()} + den = {k: -v for k,v in den.items()} + + result = UnknownUnit(num, den) + result._unit = self._unit ** power + return result + case _: + return NotImplemented + + def equivalent(self: Self, other: "Unit"): + match other: + case UnknownUnit(): + return self._unit.equivalent(other._unit) and sorted(self._numerator) == sorted(other._numerator) and sorted(self._denominator) == sorted(other._denominator) + case _: + return False + + def _reduce(self): + """Remove redundant units""" + for k in self._denominator: + if k in self._numerator: + common = min(self._numerator[k], self._denominator[k]) + self._numerator[k] -= common + self._denominator[k] -= common + dead_nums = [k for k in self._numerator if self._numerator[k] == 0] + for k in dead_nums: + del self._numerator[k] + dead_dens = [k for k in self._denominator if self._denominator[k] == 0] + for k in dead_dens: + del self._denominator[k] + return self + + def __str__(self): + result = self._name() + if type(self._unit) is NamedUnit and self._unit.name.strip(): + result += f" {self._unit.name.strip()}" + if type(self._unit) is Unit and str(self._unit).strip(): + result += f" {str(self._unit).strip()}" + return result + + def __repr__(self): + return str(self) -@dataclass -class ProcessedUnitToken: - """ Mid processing representation of formatted units """ - base_string: str - exponent_string: str - latex_exponent_string: str - exponent: int - -class UnitFormatProcessor: - """ Represents a step in the unit processing pipeline""" - def apply(self, scale, dimensions) -> tuple[ProcessedUnitToken, float, Dimensions]: - """ This will be called to deal with each processing stage""" - -class RequiredUnitFormatProcessor(UnitFormatProcessor): - """ This unit is required to exist in the formatting """ - def __init__(self, unit: Unit, power: int = 1): - self.unit = unit - self.power = power - def apply(self, scale, dimensions) -> tuple[float, Dimensions, ProcessedUnitToken]: - new_scale = scale / (self.unit.scale * self.power) - new_dimensions = self.unit.dimensions / (dimensions**self.power) - token = ProcessedUnitToken(self.unit, self.power) - - return new_scale, new_dimensions, token -class GreedyAbsDimensionUnitFormatProcessor(UnitFormatProcessor): - """ This processor minimises the dimensionality of the unit by multiplying by as many - units of the specified type as needed """ - def __init__(self, unit: Unit): - self.unit = unit - - def apply(self, scale, dimensions) -> tuple[ProcessedUnitToken, float, Dimensions]: - pass - -class GreedyAbsDimensionUnitFormatProcessor(UnitFormatProcessor): - pass class UnitGroup: """ A group of units that all have the same dimensionality """ diff --git a/test/quantities/utest_units.py b/test/quantities/utest_units.py index 3bc775313..c0d11b81a 100644 --- a/test/quantities/utest_units.py +++ b/test/quantities/utest_units.py @@ -1,72 +1,132 @@ import math -import sasdata.quantities.units as units -from sasdata.quantities.units import Unit - - -class EqualUnits: - def __init__(self, test_name: str, *units): - self.test_name = "Equality: " + test_name - self.units: list[Unit] = list(units) - - def run_test(self): - for i, unit_1 in enumerate(self.units): - for unit_2 in self.units[i + 1 :]: - assert unit_1.equivalent(unit_2), "Units should be equivalent" - assert unit_1 == unit_2, "Units should be equal" - - -class EquivalentButUnequalUnits: - def __init__(self, test_name: str, *units): - self.test_name = "Equivalence: " + test_name - self.units: list[Unit] = list(units) - - def run_test(self): - for i, unit_1 in enumerate(self.units): - for unit_2 in self.units[i + 1 :]: - assert unit_1.equivalent(unit_2), "Units should be equivalent" - assert unit_1 != unit_2, "Units should not be equal" - - -class DissimilarUnits: - def __init__(self, test_name: str, *units): - self.test_name = "Dissimilar: " + test_name - self.units: list[Unit] = list(units) - - def run_test(self): - for i, unit_1 in enumerate(self.units): - for unit_2 in self.units[i + 1 :]: - assert not unit_1.equivalent(unit_2), "Units should not be equivalent" - +import pytest -tests = [ - - EqualUnits("Pressure", - units.pascals, - units.newtons / units.meters ** 2, - units.micronewtons * units.millimeters ** -2), - - EqualUnits("Resistance", - units.ohms, - units.volts / units.amperes, - 1e-3/units.millisiemens), - - EquivalentButUnequalUnits("Angular frequency", - units.rotations / units.minutes, - units.degrees * units.hertz), - - EqualUnits("Angular frequency", - (units.rotations/units.minutes ), - (units.radians*units.hertz) * 2 * math.pi/60.0), - - DissimilarUnits("Frequency and Angular frequency", - (units.rotations/units.minutes), - (units.hertz)), - - -] - - -for test in tests: - print(test.test_name) - test.run_test() +import sasdata.quantities.units as units +from sasdata.quantities.units import UnknownUnit + +EQUAL_TERMS = { + "Pressure": [units.pascals, units.newtons / units.meters**2, units.micronewtons * units.millimeters**-2], + "Resistance": [units.ohms, units.volts / units.amperes, 1e-3 / units.millisiemens], + "Angular frequency": [(units.rotations / units.minutes), (units.radians * units.hertz) * 2 * math.pi / 60.0], + "Unknown Units": [UnknownUnit("Pizzas"), UnknownUnit(["Pizzas"])], + "Unknown Fractional Units": [ + UnknownUnit("Slices", denominator=["Pizzas"]), + UnknownUnit(["Slices"], denominator=["Pizzas"]), + ], + "Unknown Multiplication": [ + UnknownUnit("Pizzas") * UnknownUnit("People"), + UnknownUnit(["Pizzas", "People"]), + ], + "Unknown Multiplication with Units": [ + UnknownUnit("Pizzas") * units.meters, + units.meters * UnknownUnit(["Pizzas"]), + ], + "Unknown Power": [ + UnknownUnit(["Slices"], denominator=["Pizza"]) * UnknownUnit(["Slices"], denominator=["Pizza"]), + UnknownUnit(["Slices"], denominator=["Pizza"]) ** 2, + ], + "Unknown Fractional Power": [ + UnknownUnit(["Pizza", "Pizza", "Pizza"]), + UnknownUnit(["Pizza", "Pizza"]) ** 1.5, + ], + "Unknown Division": [ + UnknownUnit("Slices") / UnknownUnit("Pizza"), + UnknownUnit(["Slices"], denominator=["Pizza"]), + (1 / UnknownUnit("Pizza")) * UnknownUnit("Slices"), + 1 / (UnknownUnit("Pizza") / UnknownUnit("Slices")), + ], + "Unknown Complicated Math": [ + (UnknownUnit("Slices") / UnknownUnit("Person")) + / (UnknownUnit("Slices") / UnknownUnit("Pizzas")) + * UnknownUnit("Person"), + UnknownUnit("Pizzas"), + ], +} + + +@pytest.fixture(params=EQUAL_TERMS) +def equal_term(request): + return EQUAL_TERMS[request.param] + + +def test_unit_equality(equal_term): + for i, unit_1 in enumerate(equal_term): + for unit_2 in equal_term[i + 1 :]: + assert unit_1.equivalent(unit_2), "Units should be equivalent" + assert unit_1 == unit_2, "Units should be equal" + + +EQUIVALENT_TERMS = { + "Angular frequency": [units.rotations / units.minutes, units.degrees * units.hertz], +} + + +@pytest.fixture(params=EQUIVALENT_TERMS) +def equivalent_term(request): + return EQUIVALENT_TERMS[request.param] + + +def test_unit_equivalent(equivalent_term): + units = equivalent_term + for i, unit_1 in enumerate(units): + for unit_2 in units[i + 1 :]: + assert unit_1.equivalent(unit_2), "Units should be equivalent" + assert unit_1 != unit_2, "Units not should be equal" + + +DISSIMILAR_TERMS = { + "Frequency and Angular frequency": [(units.rotations / units.minutes), (units.hertz)], + "Different Unknown Units": [UnknownUnit("Pizzas"), UnknownUnit(["Donuts"])], + "Unknown Multiplication with Units": [ + UnknownUnit("Pizzas") * units.meters, + units.seconds * UnknownUnit(["Pizzas"]), + ], +} + + +@pytest.fixture(params=DISSIMILAR_TERMS) +def dissimilar_term(request): + return DISSIMILAR_TERMS[request.param] + + +def test_unit_dissimilar(dissimilar_term): + units = dissimilar_term + for i, unit_1 in enumerate(units): + for unit_2 in units[i + 1:]: + assert not unit_1.equivalent(unit_2), "Units should not be equivalent" + + +def test_unit_operations(): + pizza = UnknownUnit(["Pizza"]) + slice = UnknownUnit("Slice") + pineapple = UnknownUnit("Pineapple") + pie = UnknownUnit("Pie") + empty = UnknownUnit([]) + + with pytest.raises(RuntimeError): + UnknownUnit("a/b") + with pytest.raises(RuntimeError): + UnknownUnit(["a^b"]) + with pytest.raises(RuntimeError): + UnknownUnit({"a b": 1}) + with pytest.raises(RuntimeError): + UnknownUnit("a", {"a*b": 1}) + with pytest.raises(RuntimeError): + UnknownUnit("a", ["a^b"]) + with pytest.raises(RuntimeError): + UnknownUnit("a", "a/b") + + assert str(empty) == "" + + assert str(pizza) == "Pizza" + assert str(pizza * pineapple) == "Pineapple Pizza" + assert str(pizza * pizza) == "Pizza^2" + + assert str(1 / pizza) == "Pizza^-1" + assert str(1 / pizza / pineapple) == "Pineapple^-1 Pizza^-1" + assert str(slice / pizza) == "Slice Pizza^-1" + assert str(slice / pizza / pineapple) == "Slice Pineapple^-1 Pizza^-1" + assert str((slice / pizza) ** 2) == "Slice^2 Pizza^-2" + + assert str(pie**0.5) == "Pie^0.5" # A valid unit, because pie are square