diff --git a/kloppy/domain/models/event.py b/kloppy/domain/models/event.py
index 6e60fb83a..167145816 100644
--- a/kloppy/domain/models/event.py
+++ b/kloppy/domain/models/event.py
@@ -1186,6 +1186,49 @@ def aggregate(self, type_: str, **aggregator_kwargs) -> List[Any]:
 
         return aggregator.aggregate(self)
 
+    def add_synthetic_event(
+        self, event_type_: EventType, event_factory_=None, **kwargs
+    ) -> "EventDataset":
+        """
+        Add synthetic events of the specified type to the event dataset.
+
+        The generator selected for `event_type_` receives this dataset and
+        inserts the synthetic events it derives from the event stream.
+
+        Args:
+            event_type_ (EventType): The type of event to generate. The supported event types are currently:
+                - `EventType.CARRY`: Generates carry events.
+            event_factory_ (Optional[EventFactory]): An optional event factory to create the events. If not provided,
+                a default event factory will be used.
+            **kwargs: Additional configuration parameters passed to the specific synthetic event generator class.
+                The expected parameters depend on the type of event being generated (e.g., `SyntheticCarryGenerator`).
+
+        Returns:
+            EventDataset: The dataset returned by the generator. The carry generator returns the dataset it was given, after inserting the events into it.
+
+        Raises:
+            KloppyError: If the provided `event_type_` is not supported.
+
+        Example:
+            To generate synthetic carry events:
+            dataset.add_synthetic_event(EventType.CARRY, event_factory_=my_event_factory, min_length_meters=3, max_length_meters=60, max_duration=timedelta(seconds=10))
+        """
+        if event_type_ == EventType.CARRY:
+            # Local import: the carry module imports from this module, so a
+            # module-level import here would be circular.
+            from kloppy.domain.services.synthetic_event_generators.carry import (
+                SyntheticCarryGenerator,
+            )
+
+            synthetic_event_generator = SyntheticCarryGenerator(
+                event_factory_, **kwargs
+            )
+        else:
+            raise KloppyError(
+                f"Not possible to generate synthetic {event_type_}"
+            )
+        return synthetic_event_generator.add_synthetic_event(self)
+
 
 __all__ = [
     "EnumQualifier",
diff --git a/kloppy/domain/services/synthetic_event_generators/__init__.py b/kloppy/domain/services/synthetic_event_generators/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/kloppy/domain/services/synthetic_event_generators/carry.py b/kloppy/domain/services/synthetic_event_generators/carry.py
new file mode 100644
index 000000000..a57b3fa9e
--- /dev/null
+++ b/kloppy/domain/services/synthetic_event_generators/carry.py
@@ -0,0 +1,197 @@
+from datetime import timedelta
+from typing import Optional
+
+from kloppy.domain import (
+    EventDataset,
+    BodyPart,
+    CarryResult,
+    Unit,
+    EventFactory,
+    PassEvent,
+    ShotEvent,
+    TakeOnEvent,
+    ClearanceEvent,
+    InterceptionEvent,
+    DuelEvent,
+    RecoveryEvent,
+    MiscontrolEvent,
+    GoalkeeperEvent,
+    GenericEvent,
+)
+from kloppy.domain.models.event import PressureEvent, SetPieceQualifier
+from kloppy.domain.services.synthetic_event_generators.synthetic_event_generator import (
+    SyntheticEventGenerator,
+)
+
+# Event types that may delimit a carry: a carry is only generated between two
+# consecutive events of these types.
+VALID_EVENT = (
+    PassEvent,
+    ShotEvent,
+    TakeOnEvent,
+    ClearanceEvent,
+    InterceptionEvent,
+    DuelEvent,
+    RecoveryEvent,
+    MiscontrolEvent,
+    GoalkeeperEvent,
+)
+
+# Event types skipped when looking for the event that follows a candidate.
+IGNORED_EVENT = (GenericEvent, PressureEvent)
+
+# Body parts that mark a shot as headed.
+HEADED_BODY_PARTS = (BodyPart.HEAD, BodyPart.HEAD_OTHER)
+
+
+def _setting(kwargs, name, default):
+    """Return `kwargs[name]`, or `default` when it is missing or None."""
+    value = kwargs.get(name)
+    return default if value is None else value
+
+
+class SyntheticCarryGenerator(SyntheticEventGenerator):
+    """Generate carry events between consecutive `VALID_EVENT` events.
+
+    Recognised keyword arguments:
+        min_length_meters: Shortest carry to generate, in meters (default 3).
+        max_length_meters: Longest carry to generate, in meters (default 60).
+        max_duration: Largest allowed time between the two delimiting events
+            (default 10 seconds).
+    """
+
+    def __init__(self, event_factory: Optional[EventFactory] = None, **kwargs):
+        self.event_factory = event_factory or EventFactory()
+        # Compare against None rather than truthiness so that explicit zero
+        # values (e.g. `min_length_meters=0`) are honoured.
+        self.min_length_meters = _setting(kwargs, "min_length_meters", 3)
+        self.max_length_meters = _setting(kwargs, "max_length_meters", 60)
+        self.max_duration = _setting(
+            kwargs, "max_duration", timedelta(seconds=10)
+        )
+
+    def add_synthetic_event(self, dataset: EventDataset) -> EventDataset:
+        """Insert synthetic carries into `dataset` and return it.
+
+        For every event of a `VALID_EVENT` type, the next event that is not a
+        generic or pressure event is looked up. A carry is inserted directly
+        before that next event when all of the following hold: it is also a
+        `VALID_EVENT`, belongs to the same team and period, is not a set
+        piece, is not a headed shot, and the distance and time between the
+        two events are within the configured limits.
+
+        Args:
+            dataset: The event dataset to extend. It is modified in place.
+
+        Returns:
+            The same `dataset` object, with the carries inserted.
+        """
+        pitch = dataset.metadata.pitch_dimensions
+        # Iterate over a snapshot and apply the insertions afterwards, so the
+        # list is never mutated while it is being traversed.
+        events = list(dataset.events)
+        insertions = []
+
+        for idx, event in enumerate(events):
+            if not isinstance(event, VALID_EVENT):
+                continue
+
+            # Find the first following event that is not an ignored type.
+            next_idx = idx + 1
+            while next_idx < len(events) and isinstance(
+                events[next_idx], IGNORED_EVENT
+            ):
+                next_idx += 1
+            if next_idx >= len(events):
+                continue
+            next_event = events[next_idx]
+
+            if not isinstance(next_event, VALID_EVENT):
+                continue
+            if event.team.team_id != next_event.team.team_id:
+                continue
+            # not same period
+            if event.period.id != next_event.period.id:
+                continue
+            if next_event.get_qualifier_value(SetPieceQualifier) is not None:
+                continue
+            # not headed shot
+            if self._is_headed_shot(next_event):
+                continue
+
+            dt = next_event.timestamp - event.timestamp
+            # not same phase
+            if dt > self.max_duration:
+                continue
+
+            last_coord = self._carry_start_coordinates(event)
+            new_coord = next_event.coordinates
+            # Without both locations the carry length cannot be measured.
+            if last_coord is None or new_coord is None:
+                continue
+            distance_meters = pitch.distance_between(
+                new_coord, last_coord, Unit.METERS
+            )
+            # Not far enough
+            if distance_meters < self.min_length_meters:
+                continue
+            # Too far
+            if distance_meters > self.max_length_meters:
+                continue
+
+            new_carry = self.event_factory.build_carry(
+                event_id=f"carry-{event.event_id}",
+                coordinates=last_coord,
+                team=next_event.team,
+                player=next_event.player,
+                ball_owning_team=next_event.ball_owning_team,
+                ball_state=event.ball_state,
+                period=next_event.period,
+                timestamp=self._carry_start_timestamp(event, dt),
+                raw_event=None,
+                result=CarryResult.COMPLETE,
+                qualifiers=None,
+                end_coordinates=new_coord,
+                end_timestamp=next_event.timestamp,
+            )
+            # The carry goes directly before the event that ends it.
+            insertions.append((next_idx, new_carry))
+
+        # Apply from the back so earlier insertion indices stay valid.
+        for position, new_carry in reversed(insertions):
+            dataset.records.insert(position, new_carry)
+        return dataset
+
+    @staticmethod
+    def _is_headed_shot(event) -> bool:
+        """Return True when `event` is a shot with a head body-part qualifier."""
+        return isinstance(event, ShotEvent) and any(
+            qualifier.name == "body_part"
+            and qualifier.value in HEADED_BODY_PARTS
+            for qualifier in event.qualifiers or []
+        )
+
+    @staticmethod
+    def _carry_start_coordinates(event):
+        """Return the coordinates used as the start of a carry after `event`."""
+        if hasattr(event, "end_coordinates"):
+            return event.end_coordinates
+        if hasattr(event, "receiver_coordinates"):
+            return event.receiver_coordinates
+        return event.coordinates
+
+    @staticmethod
+    def _carry_start_timestamp(event, dt):
+        """Return the moment the carry starts.
+
+        Uses the end or receive timestamp of `event` when it has one;
+        otherwise assumes the carry starts a tenth of the way (`dt / 10`)
+        between `event` and the event that follows it.
+        """
+        end_timestamp = getattr(event, "end_timestamp", None)
+        if end_timestamp is not None:
+            return end_timestamp
+        receive_timestamp = getattr(event, "receive_timestamp", None)
+        if receive_timestamp is not None:
+            return receive_timestamp
+        return event.timestamp + dt / 10
diff --git a/kloppy/domain/services/synthetic_event_generators/synthetic_event_generator.py b/kloppy/domain/services/synthetic_event_generators/synthetic_event_generator.py
new file mode 100644
index 000000000..1b265116b
--- /dev/null
+++ b/kloppy/domain/services/synthetic_event_generators/synthetic_event_generator.py
@@ -0,0 +1,12 @@
+from abc import ABC, abstractmethod
+
+from kloppy.domain import EventDataset
+
+
+class SyntheticEventGenerator(ABC):
+    """Interface for services that add synthetic events to an event dataset."""
+
+    @abstractmethod
+    def add_synthetic_event(self, dataset: EventDataset) -> EventDataset:
+        """Add the generator's synthetic events to `dataset` and return it."""
+        raise NotImplementedError
diff --git a/kloppy/tests/test_synthetic_event_generator.py b/kloppy/tests/test_synthetic_event_generator.py
new file mode 100644
index 000000000..e3310fd55
--- /dev/null
+++ b/kloppy/tests/test_synthetic_event_generator.py
@@ -0,0 +1,128 @@
+from datetime import timedelta
+
+from kloppy.domain import (
+    EventType,
+    Unit,
+)
+from kloppy.utils import performance_logging
+from kloppy import statsbomb, statsperform
+
+
+class TestSyntheticEventGenerator:
+    """Tests for the synthetic event generators."""
+
+    def _load_dataset_statsperform(
+        self, base_dir, base_filename="statsperform"
+    ):
+        return statsperform.load_event(
+            ma1_data=base_dir / f"files/{base_filename}_event_ma1.json",
+            ma3_data=base_dir / f"files/{base_filename}_event_ma3.json",
+        )
+
+    def _load_dataset_statsbomb(
+        self, base_dir, base_filename="statsbomb", event_types=None
+    ):
+        return statsbomb.load(
+            event_data=base_dir / f"files/{base_filename}_event.json",
+            lineup_data=base_dir / f"files/{base_filename}_lineup.json",
+            event_types=event_types,
+        )
+
+    def calculate_carry_accuracy(
+        self, real_carries, generated_carries, real_carries_with_min_length
+    ):
+        """Score generated carries against real ones as TP / (TP + FP + FN).
+
+        A generated carry is a true positive when it matches a not yet matched
+        carry in `real_carries`; every other generated carry is a false
+        positive. A carry in `real_carries_with_min_length` without a matching
+        generated carry is a false negative.
+        """
+
+        def is_match(first_carry, second_carry):
+            return (
+                first_carry.player
+                and second_carry.player
+                and first_carry.player.player_id
+                == second_carry.player.player_id
+                and first_carry.period == second_carry.period
+                and abs(first_carry.timestamp - second_carry.timestamp)
+                < timedelta(seconds=5)
+            )
+
+        def count_matched(carries, candidates):
+            # Greedy one-to-one pairing: each candidate is used at most once.
+            used = set()
+            matched = 0
+            for carry in carries:
+                for idx, candidate in enumerate(candidates):
+                    if idx in used:
+                        continue
+                    if is_match(carry, candidate):
+                        matched += 1
+                        used.add(idx)
+                        break
+            return matched
+
+        true_positives = count_matched(generated_carries, real_carries)
+        false_positives = len(generated_carries) - true_positives
+        false_negatives = len(real_carries_with_min_length) - count_matched(
+            real_carries_with_min_length, generated_carries
+        )
+
+        total = true_positives + false_positives + false_negatives
+        # No carries at all would otherwise raise ZeroDivisionError.
+        accuracy = true_positives / total if total else 0.0
+
+        print("TP:", true_positives)
+        print("FP:", false_positives)
+        print("FN:", false_negatives)
+        print("accuracy:", accuracy)
+
+        return accuracy
+
+    def test_synthetic_carry_generator(self, base_dir):
+        dataset_with_carries = self._load_dataset_statsbomb(base_dir)
+        pitch = dataset_with_carries.metadata.pitch_dimensions
+
+        min_length_meters = 3
+        max_length_meters = 60
+        max_duration = timedelta(seconds=10)
+
+        all_statsbomb_carries = dataset_with_carries.find_all("carry")
+        all_qualifying_statsbomb_carries = [
+            carry
+            for carry in all_statsbomb_carries
+            if (
+                min_length_meters
+                <= pitch.distance_between(
+                    carry.coordinates, carry.end_coordinates, Unit.METERS
+                )
+                <= max_length_meters
+                and carry.end_timestamp - carry.timestamp < max_duration
+            )
+        ]
+
+        dataset = self._load_dataset_statsbomb(
+            base_dir,
+            event_types=[
+                event.value for event in EventType if event.value != "CARRY"
+            ],
+        )
+
+        with performance_logging("generating synthetic events"):
+            dataset = dataset.add_synthetic_event(
+                EventType.CARRY,
+                min_length_meters=min_length_meters,
+                max_length_meters=max_length_meters,
+                max_duration=max_duration,
+            )
+        all_carries = dataset.find_all("carry")
+        assert (
+            self.calculate_carry_accuracy(
+                all_statsbomb_carries,
+                all_carries,
+                all_qualifying_statsbomb_carries,
+            )
+            > 0.80
+        )