From ae53381fa3b3e91f73c5abfd2d5b5eed3a223b5b Mon Sep 17 00:00:00 2001 From: lode-mgp Date: Tue, 7 Jan 2025 15:15:00 +0100 Subject: [PATCH 01/22] work in progress: deduce carries from event stream" --- kloppy/domain/models/event.py | 11 + .../services/event_deducers/__init__.py | 0 .../domain/services/event_deducers/carry.py | 229 ++++++++++++++++++ .../services/event_deducers/event_deducer.py | 10 + kloppy/tests/test_event_deducer.py | 37 +++ 5 files changed, 287 insertions(+) create mode 100644 kloppy/domain/services/event_deducers/__init__.py create mode 100644 kloppy/domain/services/event_deducers/carry.py create mode 100644 kloppy/domain/services/event_deducers/event_deducer.py create mode 100644 kloppy/tests/test_event_deducer.py diff --git a/kloppy/domain/models/event.py b/kloppy/domain/models/event.py index 6e60fb83a..d4a191be9 100644 --- a/kloppy/domain/models/event.py +++ b/kloppy/domain/models/event.py @@ -1186,6 +1186,17 @@ def aggregate(self, type_: str, **aggregator_kwargs) -> List[Any]: return aggregator.aggregate(self) + def add_deduced_event(self, event_type_: EventType): + if event_type_ == EventType.CARRY: + from kloppy.domain.services.event_deducers.carry import ( + CarryDeducer, + ) + + deducer = CarryDeducer() + else: + raise KloppyError(f"Not possible to deduce {event_type_}") + deducer.deduce(self) + __all__ = [ "EnumQualifier", diff --git a/kloppy/domain/services/event_deducers/__init__.py b/kloppy/domain/services/event_deducers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kloppy/domain/services/event_deducers/carry.py b/kloppy/domain/services/event_deducers/carry.py new file mode 100644 index 000000000..4602a3ebc --- /dev/null +++ b/kloppy/domain/services/event_deducers/carry.py @@ -0,0 +1,229 @@ +import bisect +import math +from datetime import timedelta +from typing import List, NamedTuple, Union + +import pandas as pd + +from kloppy.domain import ( + EventDataset, + Player, + Time, + PositionType, + Event, + EventType, + BodyPart, + CarryResult, + CarryEvent, + GenericEvent, + Point, + EventFactory, + Dimension, + Unit, +) +from kloppy.domain.services.event_deducers.event_deducer import ( + EventDatasetDeduducer, +) + + +class CarryDeducer(EventDatasetDeduducer): + def deduce_old(self, dataset: EventDataset) -> List[Event]: + # TODO: config + min_dribble_length = 2 + max_dribble_length = 50 + max_dribble_duration = timedelta(seconds=5) + + events = dataset.to_df() + next_actions = events.shift(-1, fill_value=None) + + same_team = events.team_id == next_actions.team_id + not_offensive_foul = same_team & ( + next_actions.event_type != EventType.FOUL_COMMITTED + ) + + not_headed_shot = (next_actions.event_type != EventType.SHOT) & ( + ~next_actions.body_part_type.isin( + [BodyPart.HEAD, BodyPart.HEAD_OTHER] + ) + ) + + dx = events.end_coordinates_x - next_actions.coordinates_x + dy = events.end_coordinates_y - next_actions.coordinates_y + far_enough = dx**2 + dy**2 >= min_dribble_length**2 + not_too_far = dx**2 + dy**2 <= max_dribble_length**2 + + dt = next_actions.timestamp - events.timestamp + same_phase = dt < max_dribble_duration + same_period = events.period_id == next_actions.period_id + + dribble_idx = ( + same_team + & far_enough + & not_too_far + & same_phase + & same_period + & not_offensive_foul + & not_headed_shot + ) + + dribbles = pd.DataFrame() + prev = events[dribble_idx] + nex = next_actions[dribble_idx] + dribbles["period_id"] = nex.period_id + dribbles["event_id"] = prev.event_id + 0.1 + dribbles["timestamp"] = (prev.timestamp + nex.timestamp) / 2 + if "timestamp" in events.columns: + dribbles["timestamp"] = nex.timestamp + dribbles["team_id"] = nex.team_id + dribbles["player_id"] = nex.player_id + dribbles["coordinates_x"] = prev.end_coordinates_x + dribbles["coordinates_y"] = prev.end_coordinates_y + dribbles["end_coordinates_x"] = nex.coordinates_x + dribbles["end_coordinates_y"] = nex.coordinates_y + dribbles["body_part_type"] = BodyPart.RIGHT_FOOT # TODO: fix + dribbles["event_type"] = EventType.CARRY + dribbles["result"] = CarryResult.COMPLETE + + new_carries: List[Event] = [] + + # Iterate over the rows of the dribbles DataFrame and create new Event objects + for _, row in dribbles.iterrows(): + new_event = CarryEvent( + event_id=row["event_id"], + timestamp=row["timestamp"], + team_id=row["team_id"], + player_id=row["player_id"], + coordinates_x=row["coordinates_x"], + coordinates_y=row["coordinates_y"], + end_coordinates_x=row["end_coordinates_x"], + end_coordinates_y=row["end_coordinates_y"], + body_part_type=row["body_part_type"], + event_type=row["event_type"], + result=row["result"], + ) + # Append the new Event to the new_carries list + new_carries.append(new_event) + + return events.values.tolist() + + def deduce(self, dataset: EventDataset): + event_factory = EventFactory() + # TODO: config + min_dribble_length = 3 + max_dribble_length = 60 + unit = Unit("m") + min_dribble_length = unit.convert( + dataset.metadata.coordinate_system.pitch_dimensions.unit, + min_dribble_length, + ) + max_dribble_length = unit.convert( + dataset.metadata.coordinate_system.pitch_dimensions.unit, + max_dribble_length, + ) + + max_dribble_duration = timedelta(seconds=10) + new_carries = [] + for idx, event in enumerate(dataset.events): + if isinstance(event, GenericEvent): + continue + if event.event_type in [ + EventType.FOUL_COMMITTED, + EventType.CARD, + EventType.SUBSTITUTION, + EventType.FORMATION_CHANGE, + EventType.CLEARANCE, + ]: + continue + idx_sum = 1 + generic_next_event = True + while idx + idx_sum < len(dataset.events) and generic_next_event: + next_event = dataset.events[idx + idx_sum] + + if isinstance(next_event, GenericEvent): + idx += 1 + continue + else: + generic_next_event = False + if not event.team.team_id == next_event.team.team_id: + continue + + if next_event.event_type in [ + EventType.FOUL_COMMITTED, + EventType.CARD, + EventType.SUBSTITUTION, + EventType.FORMATION_CHANGE, + ]: + continue + # not headed shot + if ( + (hasattr(next_event, "body_part")) + and (next_event.event_type == EventType.SHOT) + and ( + next_event.body_part.type.isin( + [BodyPart.HEAD, BodyPart.HEAD_OTHER] + ) + ) + ): + continue + + if hasattr(event, "end_coordinates"): + last_coord = event.end_coordinates + elif hasattr(event, "receiver_coordinates"): + # Handle the case where the attribute doesn't exist + last_coord = event.receiver_coordinates + else: + last_coord = event.coordinates + + new_coord = next_event.coordinates + + # Not far enough + if new_coord.distance_to(last_coord) < min_dribble_length: + continue + # Too far + if new_coord.distance_to(last_coord) > max_dribble_length: + continue + + dt = next_event.timestamp - event.timestamp + # not same phase + if dt > max_dribble_duration: + continue + # not same period + if not event.period.id == next_event.period.id: + continue + + if hasattr(event, "end_timestamp"): + last_timestamp = event.end_timestamp + else: + last_timestamp = ( + event.timestamp + + (next_event.timestamp - event.timestamp) / 10 + ) + + generic_event_args = { + "event_id": 1, # TODO: generate event id + "coordinates": last_coord, + "team": next_event.team, + "player": next_event.player, + "ball_owning_team": next_event.ball_owning_team, + "ball_state": event.ball_state, + "period": next_event.period, + "timestamp": last_timestamp, + "raw_event": {}, + } + carry_event_args = { + "result": CarryResult.COMPLETE, + "qualifiers": None, + "end_coordinates": new_coord, + "end_timestamp": next_event.timestamp, + } + new_carry = event_factory.build_carry( + **carry_event_args, **generic_event_args + ) + new_carries.append(new_carry) + + for new_carry in new_carries: + pos = bisect.bisect_left( + [e.time for e in dataset.events], new_carry.time + ) + dataset.records.insert(pos, new_carry) + print(f"total carries: {len(new_carries)}/{len(dataset.events)}") diff --git a/kloppy/domain/services/event_deducers/event_deducer.py b/kloppy/domain/services/event_deducers/event_deducer.py new file mode 100644 index 000000000..08c4d92ff --- /dev/null +++ b/kloppy/domain/services/event_deducers/event_deducer.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod +from typing import List, NamedTuple + +from kloppy.domain import EventDataset, Event + + +class EventDatasetDeduducer(ABC): + @abstractmethod + def deduce(self, dataset: EventDataset) -> EventDataset: + raise NotImplementedError diff --git a/kloppy/tests/test_event_deducer.py b/kloppy/tests/test_event_deducer.py new file mode 100644 index 000000000..192cff18a --- /dev/null +++ b/kloppy/tests/test_event_deducer.py @@ -0,0 +1,37 @@ +from itertools import groupby + +from kloppy.domain import ( + EventType, + Event, + EventDataset, + FormationType, + CarryEvent, +) +from kloppy.domain.services.state_builder.builder import StateBuilder +from kloppy.utils import performance_logging +from kloppy import statsbomb, statsperform + + +class TestStateBuilder: + """""" + + def _load_dataset(self, base_dir, base_filename="statsperform"): + return statsperform.load_event( + ma1_data=base_dir / f"files/{base_filename}_event_ma1.json", + ma3_data=base_dir / f"files/{base_filename}_event_ma3.json", + coordinates="statsbomb", + ) + + def test_carry_deducer(self, base_dir): + dataset = self._load_dataset(base_dir) + + with performance_logging("deduce_events"): + dataset.add_deduced_event(EventType.CARRY) + carry = dataset.find("carry") + index = dataset.events.index(carry) + print(carry) + # Assert end location is equal to start location of next action + assert carry.end_coordinates == dataset.events[index + 1].coordinates + assert carry.player == dataset.events[index + 1].player + + print(dataset.to_df()[:40].to_string()) From 7cc848c6925f4c76a6f62f0d47ac92c47ad5e4ab Mon Sep 17 00:00:00 2001 From: lodevt Date: Tue, 7 Jan 2025 15:25:09 +0100 Subject: [PATCH 02/22] clean up --- kloppy/tests/test_event_deducer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/kloppy/tests/test_event_deducer.py b/kloppy/tests/test_event_deducer.py index 192cff18a..cce9fd4de 100644 --- a/kloppy/tests/test_event_deducer.py +++ b/kloppy/tests/test_event_deducer.py @@ -12,7 +12,7 @@ from kloppy import statsbomb, statsperform -class TestStateBuilder: +class TestEventDeducer: """""" def _load_dataset(self, base_dir, base_filename="statsperform"): @@ -29,9 +29,6 @@ def test_carry_deducer(self, base_dir): dataset.add_deduced_event(EventType.CARRY) carry = dataset.find("carry") index = dataset.events.index(carry) - print(carry) # Assert end location is equal to start location of next action assert carry.end_coordinates == dataset.events[index + 1].coordinates assert carry.player == dataset.events[index + 1].player - - print(dataset.to_df()[:40].to_string()) From f2116fef7f7fd8e6c80720739bcbdcec4bb55ec2 Mon Sep 17 00:00:00 2001 From: lodevt Date: Tue, 7 Jan 2025 15:27:01 +0100 Subject: [PATCH 03/22] remove unused imports --- kloppy/domain/services/event_deducers/carry.py | 8 +------- kloppy/domain/services/event_deducers/event_deducer.py | 4 +--- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/kloppy/domain/services/event_deducers/carry.py b/kloppy/domain/services/event_deducers/carry.py index 4602a3ebc..a449ade26 100644 --- a/kloppy/domain/services/event_deducers/carry.py +++ b/kloppy/domain/services/event_deducers/carry.py @@ -1,24 +1,18 @@ import bisect -import math from datetime import timedelta -from typing import List, NamedTuple, Union +from typing import List import pandas as pd from kloppy.domain import ( EventDataset, - Player, - Time, - PositionType, Event, EventType, BodyPart, CarryResult, CarryEvent, GenericEvent, - Point, EventFactory, - Dimension, Unit, ) from kloppy.domain.services.event_deducers.event_deducer import ( diff --git a/kloppy/domain/services/event_deducers/event_deducer.py b/kloppy/domain/services/event_deducers/event_deducer.py index 08c4d92ff..459c75b77 100644 --- a/kloppy/domain/services/event_deducers/event_deducer.py +++ b/kloppy/domain/services/event_deducers/event_deducer.py @@ -1,7 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, NamedTuple - -from kloppy.domain import EventDataset, Event +from kloppy.domain import EventDataset class EventDatasetDeduducer(ABC): From c45106d2a6545080462967431355b4ce189a0dc0 Mon Sep 17 00:00:00 2001 From: lodevt Date: Tue, 7 Jan 2025 15:30:42 +0100 Subject: [PATCH 04/22] more clean up --- .../domain/services/event_deducers/carry.py | 84 +------------------ 1 file changed, 3 insertions(+), 81 deletions(-) diff --git a/kloppy/domain/services/event_deducers/carry.py b/kloppy/domain/services/event_deducers/carry.py index a449ade26..29c3703aa 100644 --- a/kloppy/domain/services/event_deducers/carry.py +++ b/kloppy/domain/services/event_deducers/carry.py @@ -21,90 +21,14 @@ class CarryDeducer(EventDatasetDeduducer): - def deduce_old(self, dataset: EventDataset) -> List[Event]: - # TODO: config - min_dribble_length = 2 - max_dribble_length = 50 - max_dribble_duration = timedelta(seconds=5) - - events = dataset.to_df() - next_actions = events.shift(-1, fill_value=None) - - same_team = events.team_id == next_actions.team_id - not_offensive_foul = same_team & ( - next_actions.event_type != EventType.FOUL_COMMITTED - ) - - not_headed_shot = (next_actions.event_type != EventType.SHOT) & ( - ~next_actions.body_part_type.isin( - [BodyPart.HEAD, BodyPart.HEAD_OTHER] - ) - ) - - dx = events.end_coordinates_x - next_actions.coordinates_x - dy = events.end_coordinates_y - next_actions.coordinates_y - far_enough = dx**2 + dy**2 >= min_dribble_length**2 - not_too_far = dx**2 + dy**2 <= max_dribble_length**2 - - dt = next_actions.timestamp - events.timestamp - same_phase = dt < max_dribble_duration - same_period = events.period_id == next_actions.period_id - - dribble_idx = ( - same_team - & far_enough - & not_too_far - & same_phase - & same_period - & not_offensive_foul - & not_headed_shot - ) - - dribbles = pd.DataFrame() - prev = events[dribble_idx] - nex = next_actions[dribble_idx] - dribbles["period_id"] = nex.period_id - dribbles["event_id"] = prev.event_id + 0.1 - dribbles["timestamp"] = (prev.timestamp + nex.timestamp) / 2 - if "timestamp" in events.columns: - dribbles["timestamp"] = nex.timestamp - dribbles["team_id"] = nex.team_id - dribbles["player_id"] = nex.player_id - dribbles["coordinates_x"] = prev.end_coordinates_x - dribbles["coordinates_y"] = prev.end_coordinates_y - dribbles["end_coordinates_x"] = nex.coordinates_x - dribbles["end_coordinates_y"] = nex.coordinates_y - dribbles["body_part_type"] = BodyPart.RIGHT_FOOT # TODO: fix - dribbles["event_type"] = EventType.CARRY - dribbles["result"] = CarryResult.COMPLETE - - new_carries: List[Event] = [] - - # Iterate over the rows of the dribbles DataFrame and create new Event objects - for _, row in dribbles.iterrows(): - new_event = CarryEvent( - event_id=row["event_id"], - timestamp=row["timestamp"], - team_id=row["team_id"], - player_id=row["player_id"], - coordinates_x=row["coordinates_x"], - coordinates_y=row["coordinates_y"], - end_coordinates_x=row["end_coordinates_x"], - end_coordinates_y=row["end_coordinates_y"], - body_part_type=row["body_part_type"], - event_type=row["event_type"], - result=row["result"], - ) - # Append the new Event to the new_carries list - new_carries.append(new_event) - - return events.values.tolist() - def deduce(self, dataset: EventDataset): event_factory = EventFactory() + # TODO: config min_dribble_length = 3 max_dribble_length = 60 + max_dribble_duration = timedelta(seconds=10) + unit = Unit("m") min_dribble_length = unit.convert( dataset.metadata.coordinate_system.pitch_dimensions.unit, @@ -115,7 +39,6 @@ def deduce(self, dataset: EventDataset): max_dribble_length, ) - max_dribble_duration = timedelta(seconds=10) new_carries = [] for idx, event in enumerate(dataset.events): if isinstance(event, GenericEvent): @@ -163,7 +86,6 @@ def deduce(self, dataset: EventDataset): if hasattr(event, "end_coordinates"): last_coord = event.end_coordinates elif hasattr(event, "receiver_coordinates"): - # Handle the case where the attribute doesn't exist last_coord = event.receiver_coordinates else: last_coord = event.coordinates From db8ca7e1ca2cac23c2fcc59f92727503a17848a5 Mon Sep 17 00:00:00 2001 From: lodevt Date: Tue, 7 Jan 2025 16:04:01 +0100 Subject: [PATCH 05/22] generate event id and some refactoring --- .../domain/services/event_deducers/carry.py | 57 +++++++++---------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/kloppy/domain/services/event_deducers/carry.py b/kloppy/domain/services/event_deducers/carry.py index 29c3703aa..c7191601a 100644 --- a/kloppy/domain/services/event_deducers/carry.py +++ b/kloppy/domain/services/event_deducers/carry.py @@ -1,4 +1,5 @@ import bisect +import uuid from datetime import timedelta from typing import List @@ -25,36 +26,39 @@ def deduce(self, dataset: EventDataset): event_factory = EventFactory() # TODO: config - min_dribble_length = 3 - max_dribble_length = 60 - max_dribble_duration = timedelta(seconds=10) + min_carry_length = 3 + max_carry_length = 60 + max_carry_duration = timedelta(seconds=10) unit = Unit("m") - min_dribble_length = unit.convert( + min_carry_length = unit.convert( dataset.metadata.coordinate_system.pitch_dimensions.unit, - min_dribble_length, + min_carry_length, ) - max_dribble_length = unit.convert( + max_carry_length = unit.convert( dataset.metadata.coordinate_system.pitch_dimensions.unit, - max_dribble_length, + max_carry_length, ) new_carries = [] + valid_event_types = [ + EventType.PASS, + EventType.SHOT, + EventType.TAKE_ON, + EventType.INTERCEPTION, + EventType.DUEL, + EventType.RECOVERY, + EventType.MISCONTROL, + EventType.GOALKEEPER, + EventType.PRESSURE, + ] for idx, event in enumerate(dataset.events): - if isinstance(event, GenericEvent): + if event.event_type not in valid_event_types: continue - if event.event_type in [ - EventType.FOUL_COMMITTED, - EventType.CARD, - EventType.SUBSTITUTION, - EventType.FORMATION_CHANGE, - EventType.CLEARANCE, - ]: - continue - idx_sum = 1 + idx_plus = 1 generic_next_event = True - while idx + idx_sum < len(dataset.events) and generic_next_event: - next_event = dataset.events[idx + idx_sum] + while idx + idx_plus < len(dataset.events) and generic_next_event: + next_event = dataset.events[idx + idx_plus] if isinstance(next_event, GenericEvent): idx += 1 @@ -64,12 +68,7 @@ def deduce(self, dataset: EventDataset): if not event.team.team_id == next_event.team.team_id: continue - if next_event.event_type in [ - EventType.FOUL_COMMITTED, - EventType.CARD, - EventType.SUBSTITUTION, - EventType.FORMATION_CHANGE, - ]: + if next_event.event_type not in valid_event_types: continue # not headed shot if ( @@ -93,15 +92,15 @@ def deduce(self, dataset: EventDataset): new_coord = next_event.coordinates # Not far enough - if new_coord.distance_to(last_coord) < min_dribble_length: + if new_coord.distance_to(last_coord) < min_carry_length: continue # Too far - if new_coord.distance_to(last_coord) > max_dribble_length: + if new_coord.distance_to(last_coord) > max_carry_length: continue dt = next_event.timestamp - event.timestamp # not same phase - if dt > max_dribble_duration: + if dt > max_carry_duration: continue # not same period if not event.period.id == next_event.period.id: @@ -116,7 +115,7 @@ def deduce(self, dataset: EventDataset): ) generic_event_args = { - "event_id": 1, # TODO: generate event id + "event_id": str(uuid.uuid4()), "coordinates": last_coord, "team": next_event.team, "player": next_event.player, From a64a0710859d5a5021557ec4912ab5df5a704067 Mon Sep 17 00:00:00 2001 From: lodevt Date: Tue, 7 Jan 2025 16:10:56 +0100 Subject: [PATCH 06/22] remove print statement --- kloppy/domain/services/event_deducers/carry.py | 1 - 1 file changed, 1 deletion(-) diff --git a/kloppy/domain/services/event_deducers/carry.py b/kloppy/domain/services/event_deducers/carry.py index c7191601a..3ab69d626 100644 --- a/kloppy/domain/services/event_deducers/carry.py +++ b/kloppy/domain/services/event_deducers/carry.py @@ -141,4 +141,3 @@ def deduce(self, dataset: EventDataset): [e.time for e in dataset.events], new_carry.time ) dataset.records.insert(pos, new_carry) - print(f"total carries: {len(new_carries)}/{len(dataset.events)}") From 60ad9a31eba9dd75cffff36338d4bc4ee0f92b04 Mon Sep 17 00:00:00 2001 From: lodevt Date: Tue, 7 Jan 2025 21:19:12 +0100 Subject: [PATCH 07/22] using pitch.distance_between to always have the distance in meters --- .../domain/services/event_deducers/carry.py | 56 +++++++++---------- kloppy/tests/test_event_deducer.py | 36 +++++++++++- 2 files changed, 58 insertions(+), 34 deletions(-) diff --git a/kloppy/domain/services/event_deducers/carry.py b/kloppy/domain/services/event_deducers/carry.py index 3ab69d626..9bb70115a 100644 --- a/kloppy/domain/services/event_deducers/carry.py +++ b/kloppy/domain/services/event_deducers/carry.py @@ -1,17 +1,12 @@ import bisect import uuid from datetime import timedelta -from typing import List - -import pandas as pd from kloppy.domain import ( EventDataset, - Event, EventType, BodyPart, CarryResult, - CarryEvent, GenericEvent, EventFactory, Unit, @@ -22,36 +17,32 @@ class CarryDeducer(EventDatasetDeduducer): - def deduce(self, dataset: EventDataset): - event_factory = EventFactory() + min_carry_length_meters = 3 + max_carry_length_meters = 60 + max_carry_duration = timedelta(seconds=10) + event_factory = EventFactory() - # TODO: config - min_carry_length = 3 - max_carry_length = 60 - max_carry_duration = timedelta(seconds=10) - - unit = Unit("m") - min_carry_length = unit.convert( - dataset.metadata.coordinate_system.pitch_dimensions.unit, - min_carry_length, - ) - max_carry_length = unit.convert( - dataset.metadata.coordinate_system.pitch_dimensions.unit, - max_carry_length, - ) + def deduce(self, dataset: EventDataset): + pitch = dataset.metadata.pitch_dimensions new_carries = [] valid_event_types = [ EventType.PASS, - EventType.SHOT, EventType.TAKE_ON, - EventType.INTERCEPTION, EventType.DUEL, EventType.RECOVERY, EventType.MISCONTROL, EventType.GOALKEEPER, - EventType.PRESSURE, ] + valid_next_event_types = [ + EventType.PASS, + EventType.SHOT, + EventType.TAKE_ON, + EventType.DUEL, + EventType.MISCONTROL, + EventType.GOALKEEPER, + ] + for idx, event in enumerate(dataset.events): if event.event_type not in valid_event_types: continue @@ -61,14 +52,14 @@ def deduce(self, dataset: EventDataset): next_event = dataset.events[idx + idx_plus] if isinstance(next_event, GenericEvent): - idx += 1 + idx_plus += 1 continue else: generic_next_event = False if not event.team.team_id == next_event.team.team_id: continue - if next_event.event_type not in valid_event_types: + if next_event.event_type not in valid_next_event_types: continue # not headed shot if ( @@ -91,16 +82,19 @@ def deduce(self, dataset: EventDataset): new_coord = next_event.coordinates + distance_meters = pitch.distance_between( + new_coord, last_coord, Unit.METERS + ) # Not far enough - if new_coord.distance_to(last_coord) < min_carry_length: + if distance_meters < self.min_carry_length_meters: continue # Too far - if new_coord.distance_to(last_coord) > max_carry_length: + if distance_meters > self.max_carry_length_meters: continue dt = next_event.timestamp - event.timestamp # not same phase - if dt > max_carry_duration: + if dt > self.max_carry_duration: continue # not same period if not event.period.id == next_event.period.id: @@ -115,7 +109,7 @@ def deduce(self, dataset: EventDataset): ) generic_event_args = { - "event_id": str(uuid.uuid4()), + "event_id": f"{str(uuid.uuid4())}", "coordinates": last_coord, "team": next_event.team, "player": next_event.player, @@ -131,7 +125,7 @@ def deduce(self, dataset: EventDataset): "end_coordinates": new_coord, "end_timestamp": next_event.timestamp, } - new_carry = event_factory.build_carry( + new_carry = self.event_factory.build_carry( **carry_event_args, **generic_event_args ) new_carries.append(new_carry) diff --git a/kloppy/tests/test_event_deducer.py b/kloppy/tests/test_event_deducer.py index cce9fd4de..f44588b53 100644 --- a/kloppy/tests/test_event_deducer.py +++ b/kloppy/tests/test_event_deducer.py @@ -6,6 +6,7 @@ EventDataset, FormationType, CarryEvent, + Unit, ) from kloppy.domain.services.state_builder.builder import StateBuilder from kloppy.utils import performance_logging @@ -15,15 +16,41 @@ class TestEventDeducer: """""" - def _load_dataset(self, base_dir, base_filename="statsperform"): + def _load_dataset_statsperform( + self, base_dir, base_filename="statsperform" + ): return statsperform.load_event( ma1_data=base_dir / f"files/{base_filename}_event_ma1.json", ma3_data=base_dir / f"files/{base_filename}_event_ma3.json", - coordinates="statsbomb", + ) + + def _load_dataset_statsbomb( + self, base_dir, base_filename="statsbomb", event_types=None + ): + return statsbomb.load( + event_data=base_dir / f"files/{base_filename}_event.json", + lineup_data=base_dir / f"files/{base_filename}_lineup.json", + event_types=event_types, ) def test_carry_deducer(self, base_dir): - dataset = self._load_dataset(base_dir) + dataset_with_carries = self._load_dataset_statsbomb(base_dir) + pitch = dataset_with_carries.metadata.pitch_dimensions + all_statsbomb_caries = [ + carry + for carry in dataset_with_carries.find_all("carry") + if pitch.distance_between( + carry.coordinates, carry.end_coordinates, Unit.METERS + ) + >= 3 + ] + + dataset = self._load_dataset_statsbomb( + base_dir, + event_types=[ + event.value for event in EventType if event.value != "CARRY" + ], + ) with performance_logging("deduce_events"): dataset.add_deduced_event(EventType.CARRY) @@ -32,3 +59,6 @@ def test_carry_deducer(self, base_dir): # Assert end location is equal to start location of next action assert carry.end_coordinates == dataset.events[index + 1].coordinates assert carry.player == dataset.events[index + 1].player + all_carries = dataset.find_all("carry") + print("Original number of carries", len(all_statsbomb_caries)) + print("Generated amount of carries", len(all_carries)) From dd1f8f9c9bb755cd169633dddd938fa130b3dd8f Mon Sep 17 00:00:00 2001 From: lodevt Date: Wed, 8 Jan 2025 10:14:38 +0100 Subject: [PATCH 08/22] calculate accuracy in testing and handle pressure actions when generating carries --- .../domain/services/event_deducers/carry.py | 19 +++-- kloppy/tests/test_event_deducer.py | 69 +++++++++++++++++-- 2 files changed, 74 insertions(+), 14 deletions(-) diff --git a/kloppy/domain/services/event_deducers/carry.py b/kloppy/domain/services/event_deducers/carry.py index 9bb70115a..75cb6b02a 100644 --- a/kloppy/domain/services/event_deducers/carry.py +++ b/kloppy/domain/services/event_deducers/carry.py @@ -26,19 +26,15 @@ def deduce(self, dataset: EventDataset): pitch = dataset.metadata.pitch_dimensions new_carries = [] + valid_event_types = [ - EventType.PASS, - EventType.TAKE_ON, - EventType.DUEL, - EventType.RECOVERY, - EventType.MISCONTROL, - EventType.GOALKEEPER, - ] - valid_next_event_types = [ EventType.PASS, EventType.SHOT, EventType.TAKE_ON, + EventType.CLEARANCE, + EventType.INTERCEPTION, EventType.DUEL, + EventType.RECOVERY, EventType.MISCONTROL, EventType.GOALKEEPER, ] @@ -51,7 +47,10 @@ def deduce(self, dataset: EventDataset): while idx + idx_plus < len(dataset.events) and generic_next_event: next_event = dataset.events[idx + idx_plus] - if isinstance(next_event, GenericEvent): + if next_event.event_type in [ + EventType.GENERIC, + EventType.PRESSURE, + ]: idx_plus += 1 continue else: @@ -59,7 +58,7 @@ def deduce(self, dataset: EventDataset): if not event.team.team_id == next_event.team.team_id: continue - if next_event.event_type not in valid_next_event_types: + if next_event.event_type not in valid_event_types: continue # not headed shot if ( diff --git a/kloppy/tests/test_event_deducer.py b/kloppy/tests/test_event_deducer.py index f44588b53..41c0974d4 100644 --- a/kloppy/tests/test_event_deducer.py +++ b/kloppy/tests/test_event_deducer.py @@ -1,3 +1,4 @@ +from datetime import timedelta from itertools import groupby from kloppy.domain import ( @@ -33,12 +34,66 @@ def _load_dataset_statsbomb( event_types=event_types, ) + def calculate_carry_accuracy( + self, real_carries, deduced_carries, real_carries_with_min_length + ): + def is_match(real_carry, deduced_carry): + return ( + real_carry.player + and deduced_carry.player + and real_carry.player.player_id + == deduced_carry.player.player_id + and real_carry.period == deduced_carry.period + and abs(real_carry.timestamp - deduced_carry.timestamp) + < timedelta(seconds=5) + ) + + true_positives = 0 + matched_real_carries = set() + for deduced_carry in deduced_carries: + for idx, real_carry in enumerate(real_carries): + if idx in matched_real_carries: + continue + if is_match(real_carry, deduced_carry): + true_positives += 1 + matched_real_carries.add(idx) + break + + false_negatives = 0 + matched_deduced_carries = set() + for real_carry in real_carries_with_min_length: + found_match = False + for idx, deduced_carry in enumerate(deduced_carries): + if idx in matched_deduced_carries: + continue + if is_match(real_carry, deduced_carry): + found_match = True + matched_deduced_carries.add(idx) + break + if not found_match: + false_negatives += 1 + + false_positives = len(deduced_carries) - true_positives + + accuracy = true_positives / ( + true_positives + false_positives + false_negatives + ) + + print("TP:", true_positives) + print("FP:", false_positives) + print("FN:", false_negatives) + print("accuracy:", accuracy) + + return accuracy + def test_carry_deducer(self, base_dir): dataset_with_carries = self._load_dataset_statsbomb(base_dir) pitch = dataset_with_carries.metadata.pitch_dimensions - all_statsbomb_caries = [ + + all_statsbomb_caries = dataset_with_carries.find_all("carry") + all_statsbomb_caries_with_min_length = [ carry - for carry in dataset_with_carries.find_all("carry") + for carry in all_statsbomb_caries if pitch.distance_between( carry.coordinates, carry.end_coordinates, Unit.METERS ) @@ -60,5 +115,11 @@ def test_carry_deducer(self, base_dir): assert carry.end_coordinates == dataset.events[index + 1].coordinates assert carry.player == dataset.events[index + 1].player all_carries = dataset.find_all("carry") - print("Original number of carries", len(all_statsbomb_caries)) - print("Generated amount of carries", len(all_carries)) + assert ( + self.calculate_carry_accuracy( + all_statsbomb_caries, + all_carries, + all_statsbomb_caries_with_min_length, + ) + > 0.80 + ) From fbda85d421f3ecefb0d35bf8dddfbfc841f2064d Mon Sep 17 00:00:00 2001 From: lodevt Date: Wed, 8 Jan 2025 14:55:23 +0100 Subject: [PATCH 09/22] use receive_timestamp as end_timestamp when available --- kloppy/domain/services/event_deducers/carry.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/kloppy/domain/services/event_deducers/carry.py b/kloppy/domain/services/event_deducers/carry.py index 75cb6b02a..3105063d8 100644 --- a/kloppy/domain/services/event_deducers/carry.py +++ b/kloppy/domain/services/event_deducers/carry.py @@ -101,6 +101,8 @@ def deduce(self, dataset: EventDataset): if hasattr(event, "end_timestamp"): last_timestamp = event.end_timestamp + elif hasattr(event, "receive_timestamp"): + last_timestamp = event.receive_timestamp else: last_timestamp = ( event.timestamp From 85608f6ca2c91eeab99118b63a9099b8dd49cae2 Mon Sep 17 00:00:00 2001 From: lodevt Date: Wed, 8 Jan 2025 16:00:39 +0100 Subject: [PATCH 10/22] add 0.1 seconds to timestamp of generated carries so they are placed after the generic ball receipt event --- kloppy/domain/services/event_deducers/carry.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/kloppy/domain/services/event_deducers/carry.py b/kloppy/domain/services/event_deducers/carry.py index 3105063d8..a7540e05b 100644 --- a/kloppy/domain/services/event_deducers/carry.py +++ b/kloppy/domain/services/event_deducers/carry.py @@ -100,9 +100,13 @@ def deduce(self, dataset: EventDataset): continue if hasattr(event, "end_timestamp"): - last_timestamp = event.end_timestamp + last_timestamp = event.end_timestamp + timedelta( + seconds=0.1 + ) elif hasattr(event, "receive_timestamp"): - last_timestamp = event.receive_timestamp + last_timestamp = event.receive_timestamp + timedelta( + seconds=0.1 + ) else: last_timestamp = ( event.timestamp From 3a9313e0fd5e55df7eaf995796ad639b06656501 Mon Sep 17 00:00:00 2001 From: lodevt Date: Mon, 13 Jan 2025 14:53:39 +0100 Subject: [PATCH 11/22] renaming of files and variables --- kloppy/domain/models/event.py | 26 ++++++++--- .../__init__.py | 0 .../carry.py | 25 ++++++----- .../synthetic_event_generator.py} | 4 +- ...r.py => test_synthetic_event_generator.py} | 44 ++++++++++--------- 5 files changed, 57 insertions(+), 42 deletions(-) rename kloppy/domain/services/{event_deducers => synthetic_event_generators}/__init__.py (100%) rename kloppy/domain/services/{event_deducers => synthetic_event_generators}/carry.py (88%) rename kloppy/domain/services/{event_deducers/event_deducer.py => synthetic_event_generators/synthetic_event_generator.py} (54%) rename kloppy/tests/{test_event_deducer.py => test_synthetic_event_generator.py} (72%) diff --git a/kloppy/domain/models/event.py b/kloppy/domain/models/event.py index d4a191be9..e908dd241 100644 --- a/kloppy/domain/models/event.py +++ b/kloppy/domain/models/event.py @@ -30,6 +30,7 @@ from .common import DataRecord, Dataset, Player, Team from .formation import FormationType from .pitch import Point +from ...config import get_config from ...exceptions import OrphanedRecordError, InvalidFilterError, KloppyError @@ -1186,16 +1187,27 @@ def aggregate(self, type_: str, **aggregator_kwargs) -> List[Any]: return aggregator.aggregate(self) - def add_deduced_event(self, event_type_: EventType): + def add_synthetic_event(self, event_type_: EventType): + """ + Adds synthetic events of the specified type. This method analyses the stream of events and inserts + synthetic events at the appropriate points within the dataset. + + Args: + event_type_ (EventType): The type of event to generate. (See [`EventType`][kloppy.domain.models.event.EventType]) + Supported event types are currently only [EventType.CARRY] + + Raises: + KloppyError: If the event type is not supported or invalid. + """ + event_factory = get_config("event_factory") if event_type_ == EventType.CARRY: - from kloppy.domain.services.event_deducers.carry import ( - CarryDeducer, + from kloppy.domain.services.synthetic_event_generators.carry import ( + SyntheticCarryGenerator, ) - - deducer = CarryDeducer() + synthetic_event_generator = SyntheticCarryGenerator(event_factory) else: - raise KloppyError(f"Not possible to deduce {event_type_}") - deducer.deduce(self) + raise KloppyError(f"Not possible to generate synthetic {event_type_}") + synthetic_event_generator.add_synthetic_event(self) __all__ = [ diff --git a/kloppy/domain/services/event_deducers/__init__.py b/kloppy/domain/services/synthetic_event_generators/__init__.py similarity index 100% rename from kloppy/domain/services/event_deducers/__init__.py rename to kloppy/domain/services/synthetic_event_generators/__init__.py diff --git a/kloppy/domain/services/event_deducers/carry.py b/kloppy/domain/services/synthetic_event_generators/carry.py similarity index 88% rename from kloppy/domain/services/event_deducers/carry.py rename to kloppy/domain/services/synthetic_event_generators/carry.py index a7540e05b..c22a3e499 100644 --- a/kloppy/domain/services/event_deducers/carry.py +++ b/kloppy/domain/services/synthetic_event_generators/carry.py @@ -7,22 +7,23 @@ EventType, BodyPart, CarryResult, - GenericEvent, EventFactory, Unit, ) -from kloppy.domain.services.event_deducers.event_deducer import ( - EventDatasetDeduducer, +from kloppy.domain.services.synthetic_event_generators.synthetic_event_generator import ( + SyntheticEventGenerator, ) -class CarryDeducer(EventDatasetDeduducer): - min_carry_length_meters = 3 - max_carry_length_meters = 60 - max_carry_duration = timedelta(seconds=10) - event_factory = EventFactory() +class SyntheticCarryGenerator(SyntheticEventGenerator): + min_length_meters = 3 + max_length_meters = 60 + max_duration = timedelta(seconds=10) - def deduce(self, dataset: EventDataset): + def __init__(self, event_factory): + self.event_factory = event_factory + + def add_synthetic_event(self, dataset: EventDataset): pitch = dataset.metadata.pitch_dimensions new_carries = [] @@ -85,15 +86,15 @@ def deduce(self, dataset: EventDataset): new_coord, last_coord, Unit.METERS ) # Not far enough - if distance_meters < self.min_carry_length_meters: + if distance_meters < self.min_length_meters: continue # Too far - if distance_meters > self.max_carry_length_meters: + if distance_meters > self.max_length_meters: continue dt = next_event.timestamp - event.timestamp # not same phase - if dt > self.max_carry_duration: + if dt > self.max_duration: continue # not same period if not event.period.id == next_event.period.id: diff --git a/kloppy/domain/services/event_deducers/event_deducer.py b/kloppy/domain/services/synthetic_event_generators/synthetic_event_generator.py similarity index 54% rename from kloppy/domain/services/event_deducers/event_deducer.py rename to kloppy/domain/services/synthetic_event_generators/synthetic_event_generator.py index 459c75b77..cf02e5428 100644 --- a/kloppy/domain/services/event_deducers/event_deducer.py +++ b/kloppy/domain/services/synthetic_event_generators/synthetic_event_generator.py @@ -2,7 +2,7 @@ from kloppy.domain import EventDataset -class EventDatasetDeduducer(ABC): +class SyntheticEventGenerator(ABC): @abstractmethod - def deduce(self, dataset: EventDataset) -> EventDataset: + def add_synthetic_event(self, dataset: EventDataset) -> EventDataset: raise NotImplementedError diff --git a/kloppy/tests/test_event_deducer.py b/kloppy/tests/test_synthetic_event_generator.py similarity index 72% rename from kloppy/tests/test_event_deducer.py rename to kloppy/tests/test_synthetic_event_generator.py index 41c0974d4..320da583b 100644 --- a/kloppy/tests/test_event_deducer.py +++ b/kloppy/tests/test_synthetic_event_generator.py @@ -14,7 +14,7 @@ from kloppy import statsbomb, statsperform -class TestEventDeducer: +class TestSyntheticEventGenerator: """""" def _load_dataset_statsperform( @@ -35,45 +35,45 @@ def _load_dataset_statsbomb( ) def calculate_carry_accuracy( - self, real_carries, deduced_carries, real_carries_with_min_length + self, real_carries, generated_carries, real_carries_with_min_length ): - def is_match(real_carry, deduced_carry): + def is_match(real_carry, generated_carry): return ( - real_carry.player - and deduced_carry.player - and real_carry.player.player_id - == deduced_carry.player.player_id - and real_carry.period == deduced_carry.period - and abs(real_carry.timestamp - deduced_carry.timestamp) - < timedelta(seconds=5) + real_carry.player + and generated_carry.player + and real_carry.player.player_id + == generated_carry.player.player_id + and real_carry.period == generated_carry.period + and abs(real_carry.timestamp - generated_carry.timestamp) + < timedelta(seconds=5) ) true_positives = 0 matched_real_carries = set() - for deduced_carry in deduced_carries: + for generated_carry in generated_carries: for idx, real_carry in enumerate(real_carries): if idx in matched_real_carries: continue - if is_match(real_carry, deduced_carry): + if is_match(real_carry, generated_carry): true_positives += 1 matched_real_carries.add(idx) break false_negatives = 0 - matched_deduced_carries = set() + matched_generated_carries = set() for real_carry in real_carries_with_min_length: found_match = False - for idx, deduced_carry in enumerate(deduced_carries): - if idx in matched_deduced_carries: + for idx, generated_carry in enumerate(generated_carries): + if idx in matched_generated_carries: continue - if is_match(real_carry, deduced_carry): + if is_match(real_carry, generated_carry): found_match = True - matched_deduced_carries.add(idx) + matched_generated_carries.add(idx) break if not found_match: false_negatives += 1 - false_positives = len(deduced_carries) - true_positives + false_positives = len(generated_carries) - true_positives accuracy = true_positives / ( true_positives + false_positives + false_negatives @@ -86,7 +86,7 @@ def is_match(real_carry, deduced_carry): return accuracy - def test_carry_deducer(self, base_dir): + def test_synthetic_carry_generator(self, base_dir): dataset_with_carries = self._load_dataset_statsbomb(base_dir) pitch = dataset_with_carries.metadata.pitch_dimensions @@ -107,8 +107,8 @@ def test_carry_deducer(self, base_dir): ], ) - with performance_logging("deduce_events"): - dataset.add_deduced_event(EventType.CARRY) + with performance_logging("generating synthetic events"): + dataset.add_synthetic_event(EventType.CARRY) carry = dataset.find("carry") index = dataset.events.index(carry) # Assert end location is equal to start location of next action @@ -123,3 +123,5 @@ def test_carry_deducer(self, base_dir): ) > 0.80 ) + + print(dataset.to_df()[:100].to_string()) From 09f7c73937e693e52dc5db736636c03ee9921f26 Mon Sep 17 00:00:00 2001 From: lodevt Date: Mon, 13 Jan 2025 15:06:44 +0100 Subject: [PATCH 12/22] formatting --- kloppy/domain/models/event.py | 5 ++++- .../tests/test_synthetic_event_generator.py | 21 +++++++------------ 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/kloppy/domain/models/event.py b/kloppy/domain/models/event.py index e908dd241..8a035858c 100644 --- a/kloppy/domain/models/event.py +++ b/kloppy/domain/models/event.py @@ -1204,9 +1204,12 @@ def add_synthetic_event(self, event_type_: EventType): from kloppy.domain.services.synthetic_event_generators.carry import ( SyntheticCarryGenerator, ) + synthetic_event_generator = SyntheticCarryGenerator(event_factory) else: - raise KloppyError(f"Not possible to generate synthetic {event_type_}") + raise KloppyError( + f"Not possible to generate synthetic {event_type_}" + ) synthetic_event_generator.add_synthetic_event(self) diff --git a/kloppy/tests/test_synthetic_event_generator.py b/kloppy/tests/test_synthetic_event_generator.py index 320da583b..29ca2b5c9 100644 --- a/kloppy/tests/test_synthetic_event_generator.py +++ b/kloppy/tests/test_synthetic_event_generator.py @@ -39,13 +39,13 @@ def calculate_carry_accuracy( ): def is_match(real_carry, generated_carry): return ( - real_carry.player - and generated_carry.player - and real_carry.player.player_id - == generated_carry.player.player_id - and real_carry.period == generated_carry.period - and abs(real_carry.timestamp - generated_carry.timestamp) - < timedelta(seconds=5) + real_carry.player + and generated_carry.player + and real_carry.player.player_id + == generated_carry.player.player_id + and real_carry.period == generated_carry.period + and abs(real_carry.timestamp - generated_carry.timestamp) + < timedelta(seconds=5) ) true_positives = 0 @@ -109,11 +109,6 @@ def test_synthetic_carry_generator(self, base_dir): with performance_logging("generating synthetic events"): dataset.add_synthetic_event(EventType.CARRY) - carry = dataset.find("carry") - index = dataset.events.index(carry) - # Assert end location is equal to start location of next action - assert carry.end_coordinates == dataset.events[index + 1].coordinates - assert carry.player == dataset.events[index + 1].player all_carries = dataset.find_all("carry") assert ( self.calculate_carry_accuracy( @@ -123,5 +118,3 @@ def test_synthetic_carry_generator(self, base_dir): ) > 0.80 ) - - print(dataset.to_df()[:100].to_string()) From 26c1ba1cfd027408056bd23036830b496de6b98f Mon Sep 17 00:00:00 2001 From: lodevt Date: Mon, 13 Jan 2025 15:16:11 +0100 Subject: [PATCH 13/22] add synthetic-prefix in event_id --- kloppy/domain/models/event.py | 4 +-- .../synthetic_event_generators/carry.py | 29 +++++-------------- .../tests/test_synthetic_event_generator.py | 6 ---- 3 files changed, 8 insertions(+), 31 deletions(-) diff --git a/kloppy/domain/models/event.py b/kloppy/domain/models/event.py index 8a035858c..9a8e61f35 100644 --- a/kloppy/domain/models/event.py +++ b/kloppy/domain/models/event.py @@ -30,7 +30,6 @@ from .common import DataRecord, Dataset, Player, Team from .formation import FormationType from .pitch import Point -from ...config import get_config from ...exceptions import OrphanedRecordError, InvalidFilterError, KloppyError @@ -1199,13 +1198,12 @@ def add_synthetic_event(self, event_type_: EventType): Raises: KloppyError: If the event type is not supported or invalid. """ - event_factory = get_config("event_factory") if event_type_ == EventType.CARRY: from kloppy.domain.services.synthetic_event_generators.carry import ( SyntheticCarryGenerator, ) - synthetic_event_generator = SyntheticCarryGenerator(event_factory) + synthetic_event_generator = SyntheticCarryGenerator() else: raise KloppyError( f"Not possible to generate synthetic {event_type_}" diff --git a/kloppy/domain/services/synthetic_event_generators/carry.py b/kloppy/domain/services/synthetic_event_generators/carry.py index c22a3e499..8eba9cbdb 100644 --- a/kloppy/domain/services/synthetic_event_generators/carry.py +++ b/kloppy/domain/services/synthetic_event_generators/carry.py @@ -1,4 +1,3 @@ -import bisect import uuid from datetime import timedelta @@ -7,8 +6,8 @@ EventType, BodyPart, CarryResult, - EventFactory, Unit, + EventFactory, ) from kloppy.domain.services.synthetic_event_generators.synthetic_event_generator import ( SyntheticEventGenerator, @@ -19,15 +18,11 @@ class SyntheticCarryGenerator(SyntheticEventGenerator): min_length_meters = 3 max_length_meters = 60 max_duration = timedelta(seconds=10) - - def __init__(self, event_factory): - self.event_factory = event_factory + event_factory = EventFactory() def add_synthetic_event(self, dataset: EventDataset): pitch = dataset.metadata.pitch_dimensions - new_carries = [] - valid_event_types = [ EventType.PASS, EventType.SHOT, @@ -101,13 +96,9 @@ def add_synthetic_event(self, dataset: EventDataset): continue if hasattr(event, "end_timestamp"): - last_timestamp = event.end_timestamp + timedelta( - seconds=0.1 - ) + last_timestamp = event.end_timestamp elif hasattr(event, "receive_timestamp"): - last_timestamp = event.receive_timestamp + timedelta( - seconds=0.1 - ) + last_timestamp = event.receive_timestamp else: last_timestamp = ( event.timestamp @@ -115,7 +106,7 @@ def add_synthetic_event(self, dataset: EventDataset): ) generic_event_args = { - "event_id": f"{str(uuid.uuid4())}", + "event_id": f"synthetic-{str(uuid.uuid4())}", "coordinates": last_coord, "team": next_event.team, "player": next_event.player, @@ -123,7 +114,7 @@ def add_synthetic_event(self, dataset: EventDataset): "ball_state": event.ball_state, "period": next_event.period, "timestamp": last_timestamp, - "raw_event": {}, + "raw_event": None, } carry_event_args = { "result": CarryResult.COMPLETE, @@ -134,10 +125,4 @@ def add_synthetic_event(self, dataset: EventDataset): new_carry = self.event_factory.build_carry( **carry_event_args, **generic_event_args ) - new_carries.append(new_carry) - - for new_carry in new_carries: - pos = bisect.bisect_left( - [e.time for e in dataset.events], new_carry.time - ) - dataset.records.insert(pos, new_carry) + dataset.records.insert(idx + idx_plus, new_carry) diff --git a/kloppy/tests/test_synthetic_event_generator.py b/kloppy/tests/test_synthetic_event_generator.py index 29ca2b5c9..ae063cdee 100644 --- a/kloppy/tests/test_synthetic_event_generator.py +++ b/kloppy/tests/test_synthetic_event_generator.py @@ -1,15 +1,9 @@ from datetime import timedelta -from itertools import groupby from kloppy.domain import ( EventType, - Event, - EventDataset, - FormationType, - CarryEvent, Unit, ) -from kloppy.domain.services.state_builder.builder import StateBuilder from kloppy.utils import performance_logging from kloppy import statsbomb, statsperform From 0ba8ecc2d3f6848dd3d7a67614539c4acee34171 Mon Sep 17 00:00:00 2001 From: lodevt Date: Mon, 13 Jan 2025 15:26:53 +0100 Subject: [PATCH 14/22] read body_part from qualifiers --- .../synthetic_event_generators/carry.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/kloppy/domain/services/synthetic_event_generators/carry.py b/kloppy/domain/services/synthetic_event_generators/carry.py index 8eba9cbdb..222526cb3 100644 --- a/kloppy/domain/services/synthetic_event_generators/carry.py +++ b/kloppy/domain/services/synthetic_event_generators/carry.py @@ -57,16 +57,15 @@ def add_synthetic_event(self, dataset: EventDataset): if next_event.event_type not in valid_event_types: continue # not headed shot - if ( - (hasattr(next_event, "body_part")) - and (next_event.event_type == EventType.SHOT) - and ( - next_event.body_part.type.isin( - [BodyPart.HEAD, BodyPart.HEAD_OTHER] - ) - ) - ): - continue + if next_event.qualifiers: + for qualifier in next_event.qualifiers: + if ( + qualifier.name == "body_part" + and next_event.event_type == EventType.SHOT + and qualifier.value + in [BodyPart.HEAD, BodyPart.HEAD_OTHER] + ): + continue if hasattr(event, "end_coordinates"): last_coord = event.end_coordinates From 4019f49edb6451a1f68bceef6dc4264233089ca4 Mon Sep 17 00:00:00 2001 From: lodevt Date: Mon, 13 Jan 2025 15:32:21 +0100 Subject: [PATCH 15/22] refactoring --- .../synthetic_event_generators/carry.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/kloppy/domain/services/synthetic_event_generators/carry.py b/kloppy/domain/services/synthetic_event_generators/carry.py index 222526cb3..445d71836 100644 --- a/kloppy/domain/services/synthetic_event_generators/carry.py +++ b/kloppy/domain/services/synthetic_event_generators/carry.py @@ -39,8 +39,8 @@ def add_synthetic_event(self, dataset: EventDataset): if event.event_type not in valid_event_types: continue idx_plus = 1 - generic_next_event = True - while idx + idx_plus < len(dataset.events) and generic_next_event: + go_to_next_event = True + while idx + idx_plus < len(dataset.events) and go_to_next_event: next_event = dataset.events[idx + idx_plus] if next_event.event_type in [ @@ -50,22 +50,12 @@ def add_synthetic_event(self, dataset: EventDataset): idx_plus += 1 continue else: - generic_next_event = False + go_to_next_event = False if not event.team.team_id == next_event.team.team_id: continue if next_event.event_type not in valid_event_types: continue - # not headed shot - if next_event.qualifiers: - for qualifier in next_event.qualifiers: - if ( - qualifier.name == "body_part" - and next_event.event_type == EventType.SHOT - and qualifier.value - in [BodyPart.HEAD, BodyPart.HEAD_OTHER] - ): - continue if hasattr(event, "end_coordinates"): last_coord = event.end_coordinates @@ -94,6 +84,14 @@ def add_synthetic_event(self, dataset: EventDataset): if not event.period.id == next_event.period.id: continue + # not headed shot + if next_event.event_type == EventType.SHOT and any( + qualifier.name == "body_part" + and qualifier.value in [BodyPart.HEAD, BodyPart.HEAD_OTHER] + for qualifier in next_event.qualifiers or [] + ): + continue + if hasattr(event, "end_timestamp"): last_timestamp = event.end_timestamp elif hasattr(event, "receive_timestamp"): From 7cb29fcaf42ecd75db7e3ca5da6252d069336ea7 Mon Sep 17 00:00:00 2001 From: lodevt Date: Tue, 14 Jan 2025 15:45:56 +0100 Subject: [PATCH 16/22] add event_factory parameter to SyntheticEventGenerator --- kloppy/domain/models/event.py | 5 +++-- .../synthetic_event_generator.py | 9 ++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/kloppy/domain/models/event.py b/kloppy/domain/models/event.py index 9a8e61f35..086e853e8 100644 --- a/kloppy/domain/models/event.py +++ b/kloppy/domain/models/event.py @@ -1186,7 +1186,7 @@ def aggregate(self, type_: str, **aggregator_kwargs) -> List[Any]: return aggregator.aggregate(self) - def add_synthetic_event(self, event_type_: EventType): + def add_synthetic_event(self, event_type_: EventType, event_factory_=None): """ Adds synthetic events of the specified type. This method analyses the stream of events and inserts synthetic events at the appropriate points within the dataset. @@ -1194,6 +1194,7 @@ def add_synthetic_event(self, event_type_: EventType): Args: event_type_ (EventType): The type of event to generate. (See [`EventType`][kloppy.domain.models.event.EventType]) Supported event types are currently only [EventType.CARRY] + event_factory_ (EventFactory): Optional event factory to generate the events Raises: KloppyError: If the event type is not supported or invalid. @@ -1203,7 +1204,7 @@ def add_synthetic_event(self, event_type_: EventType): SyntheticCarryGenerator, ) - synthetic_event_generator = SyntheticCarryGenerator() + synthetic_event_generator = SyntheticCarryGenerator(event_factory_) else: raise KloppyError( f"Not possible to generate synthetic {event_type_}" diff --git a/kloppy/domain/services/synthetic_event_generators/synthetic_event_generator.py b/kloppy/domain/services/synthetic_event_generators/synthetic_event_generator.py index cf02e5428..86f213af3 100644 --- a/kloppy/domain/services/synthetic_event_generators/synthetic_event_generator.py +++ b/kloppy/domain/services/synthetic_event_generators/synthetic_event_generator.py @@ -1,8 +1,15 @@ from abc import ABC, abstractmethod -from kloppy.domain import EventDataset +from typing import Optional + +from kloppy.domain import EventDataset, EventFactory class SyntheticEventGenerator(ABC): + def __init__(self, event_factory: Optional[EventFactory] = None): + if not event_factory: + event_factory = EventFactory() + self.event_factory = event_factory + @abstractmethod def add_synthetic_event(self, dataset: EventDataset) -> EventDataset: raise NotImplementedError From 7fbef92f94dc8d2607c04e8cfa1c293676d04b62 Mon Sep 17 00:00:00 2001 From: lodevt Date: Wed, 15 Jan 2025 09:31:01 +0100 Subject: [PATCH 17/22] allow passing configuration parameters to Synthetic Event Generator --- kloppy/domain/models/event.py | 27 +++++++++++++------ .../synthetic_event_generators/carry.py | 10 ++++--- .../synthetic_event_generator.py | 8 +----- .../tests/test_synthetic_event_generator.py | 25 ++++++++++++----- 4 files changed, 45 insertions(+), 25 deletions(-) diff --git a/kloppy/domain/models/event.py b/kloppy/domain/models/event.py index 086e853e8..8de04095b 100644 --- a/kloppy/domain/models/event.py +++ b/kloppy/domain/models/event.py @@ -1186,25 +1186,36 @@ def aggregate(self, type_: str, **aggregator_kwargs) -> List[Any]: return aggregator.aggregate(self) - def add_synthetic_event(self, event_type_: EventType, event_factory_=None): + def add_synthetic_event( + self, event_type_: EventType, event_factory_=None, **kwargs + ): """ - Adds synthetic events of the specified type. This method analyses the stream of events and inserts - synthetic events at the appropriate points within the dataset. + Adds synthetic events of the specified type to the event dataset. This method analyzes the stream of + events and inserts synthetic events at the appropriate points within the dataset based on the event type. Args: - event_type_ (EventType): The type of event to generate. (See [`EventType`][kloppy.domain.models.event.EventType]) - Supported event types are currently only [EventType.CARRY] - event_factory_ (EventFactory): Optional event factory to generate the events + event_type_ (EventType): The type of event to generate. The supported event types are currently: + - `EventType.CARRY`: Generates carry events. + event_factory_ (Optional[EventFactory]): An optional event factory to create the events. If not provided, + a default event factory will be used. + **kwargs: Additional configuration parameters passed to the specific synthetic event generator class. + The expected parameters depend on the type of event being generated (e.g., `SyntheticCarryGenerator`). Raises: - KloppyError: If the event type is not supported or invalid. + KloppyError: If the provided `event_type_` is not supported or invalid. + + Example: + To generate a synthetic carry event: + add_synthetic_event(EventType.CARRY, event_factory=my_event_factory, min_length_meters=3, max_length_meters=60, max_duration=timedelta(seconds=10)) """ if event_type_ == EventType.CARRY: from kloppy.domain.services.synthetic_event_generators.carry import ( SyntheticCarryGenerator, ) - synthetic_event_generator = SyntheticCarryGenerator(event_factory_) + synthetic_event_generator = SyntheticCarryGenerator( + event_factory_, **kwargs + ) else: raise KloppyError( f"Not possible to generate synthetic {event_type_}" diff --git a/kloppy/domain/services/synthetic_event_generators/carry.py b/kloppy/domain/services/synthetic_event_generators/carry.py index 445d71836..3b87ac30f 100644 --- a/kloppy/domain/services/synthetic_event_generators/carry.py +++ b/kloppy/domain/services/synthetic_event_generators/carry.py @@ -1,5 +1,6 @@ import uuid from datetime import timedelta +from typing import Optional from kloppy.domain import ( EventDataset, @@ -15,10 +16,11 @@ class SyntheticCarryGenerator(SyntheticEventGenerator): - min_length_meters = 3 - max_length_meters = 60 - max_duration = timedelta(seconds=10) - event_factory = EventFactory() + def __init__(self, event_factory: Optional[EventFactory] = None, **kwargs): + self.event_factory = event_factory or EventFactory() + self.min_length_meters = kwargs.get("min_length_meters") or 3 + self.max_length_meters = kwargs.get("max_length_meters") or 60 + self.max_duration = kwargs.get("max_duration") or timedelta(seconds=10) def add_synthetic_event(self, dataset: EventDataset): pitch = dataset.metadata.pitch_dimensions diff --git a/kloppy/domain/services/synthetic_event_generators/synthetic_event_generator.py b/kloppy/domain/services/synthetic_event_generators/synthetic_event_generator.py index 86f213af3..1b265116b 100644 --- a/kloppy/domain/services/synthetic_event_generators/synthetic_event_generator.py +++ b/kloppy/domain/services/synthetic_event_generators/synthetic_event_generator.py @@ -1,15 +1,9 @@ from abc import ABC, abstractmethod -from typing import Optional -from kloppy.domain import EventDataset, EventFactory +from kloppy.domain import EventDataset class SyntheticEventGenerator(ABC): - def __init__(self, event_factory: Optional[EventFactory] = None): - if not event_factory: - event_factory = EventFactory() - self.event_factory = event_factory - @abstractmethod def add_synthetic_event(self, dataset: EventDataset) -> EventDataset: raise NotImplementedError diff --git a/kloppy/tests/test_synthetic_event_generator.py b/kloppy/tests/test_synthetic_event_generator.py index ae063cdee..d1c1d26cd 100644 --- a/kloppy/tests/test_synthetic_event_generator.py +++ b/kloppy/tests/test_synthetic_event_generator.py @@ -84,14 +84,22 @@ def test_synthetic_carry_generator(self, base_dir): dataset_with_carries = self._load_dataset_statsbomb(base_dir) pitch = dataset_with_carries.metadata.pitch_dimensions + min_length_meters = 3 + max_length_meters = 60 + max_duration = timedelta(seconds=10) + all_statsbomb_caries = dataset_with_carries.find_all("carry") - all_statsbomb_caries_with_min_length = [ + all_qualifying_statsbomb_queries = [ carry for carry in all_statsbomb_caries - if pitch.distance_between( - carry.coordinates, carry.end_coordinates, Unit.METERS + if ( + min_length_meters + <= pitch.distance_between( + carry.coordinates, carry.end_coordinates, Unit.METERS + ) + <= max_length_meters + and carry.end_timestamp - carry.timestamp < max_duration ) - >= 3 ] dataset = self._load_dataset_statsbomb( @@ -102,13 +110,18 @@ def test_synthetic_carry_generator(self, base_dir): ) with performance_logging("generating synthetic events"): - dataset.add_synthetic_event(EventType.CARRY) + dataset.add_synthetic_event( + EventType.CARRY, + min_length_meters=min_length_meters, + max_length_meters=max_length_meters, + max_duration=max_duration, + ) all_carries = dataset.find_all("carry") assert ( self.calculate_carry_accuracy( all_statsbomb_caries, all_carries, - all_statsbomb_caries_with_min_length, + all_qualifying_statsbomb_queries, ) > 0.80 ) From a5926ec4e8aef1d08b20d83da4b6e9d801491ad5 Mon Sep 17 00:00:00 2001 From: lodevt Date: Thu, 16 Jan 2025 14:02:22 +0100 Subject: [PATCH 18/22] use "carry-" to generate deterministic event ids for synthetic carries" --- kloppy/domain/services/synthetic_event_generators/carry.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/kloppy/domain/services/synthetic_event_generators/carry.py b/kloppy/domain/services/synthetic_event_generators/carry.py index 3b87ac30f..8c1bfbad9 100644 --- a/kloppy/domain/services/synthetic_event_generators/carry.py +++ b/kloppy/domain/services/synthetic_event_generators/carry.py @@ -103,9 +103,8 @@ def add_synthetic_event(self, dataset: EventDataset): event.timestamp + (next_event.timestamp - event.timestamp) / 10 ) - generic_event_args = { - "event_id": f"synthetic-{str(uuid.uuid4())}", + "event_id": f"carry-{event.event_id}", "coordinates": last_coord, "team": next_event.team, "player": next_event.player, From 7979fb3c9195d637d2c531d16830796d696d7552 Mon Sep 17 00:00:00 2001 From: lodevt Date: Thu, 16 Jan 2025 16:18:57 +0100 Subject: [PATCH 19/22] ensure valid timestamps --- .../services/synthetic_event_generators/carry.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/kloppy/domain/services/synthetic_event_generators/carry.py b/kloppy/domain/services/synthetic_event_generators/carry.py index 8c1bfbad9..f433c4d5f 100644 --- a/kloppy/domain/services/synthetic_event_generators/carry.py +++ b/kloppy/domain/services/synthetic_event_generators/carry.py @@ -94,9 +94,15 @@ def add_synthetic_event(self, dataset: EventDataset): ): continue - if hasattr(event, "end_timestamp"): + if ( + hasattr(event, "end_timestamp") + and event.end_timestamp is not None + ): last_timestamp = event.end_timestamp - elif hasattr(event, "receive_timestamp"): + elif ( + hasattr(event, "receive_timestamp") + and event.receive_timestamp is not None + ): last_timestamp = event.receive_timestamp else: last_timestamp = ( From d738e317eaef5b5b7a29f108784b92cbaf406ab1 Mon Sep 17 00:00:00 2001 From: lodevt Date: Fri, 17 Jan 2025 09:41:07 +0100 Subject: [PATCH 20/22] return dataset after adding synthetic events --- kloppy/domain/models/event.py | 2 +- kloppy/domain/services/synthetic_event_generators/carry.py | 3 ++- kloppy/tests/test_synthetic_event_generator.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/kloppy/domain/models/event.py b/kloppy/domain/models/event.py index 8de04095b..167145816 100644 --- a/kloppy/domain/models/event.py +++ b/kloppy/domain/models/event.py @@ -1220,7 +1220,7 @@ def add_synthetic_event( raise KloppyError( f"Not possible to generate synthetic {event_type_}" ) - synthetic_event_generator.add_synthetic_event(self) + return synthetic_event_generator.add_synthetic_event(self) __all__ = [ diff --git a/kloppy/domain/services/synthetic_event_generators/carry.py b/kloppy/domain/services/synthetic_event_generators/carry.py index f433c4d5f..0fe0c18a8 100644 --- a/kloppy/domain/services/synthetic_event_generators/carry.py +++ b/kloppy/domain/services/synthetic_event_generators/carry.py @@ -22,7 +22,7 @@ def __init__(self, event_factory: Optional[EventFactory] = None, **kwargs): self.max_length_meters = kwargs.get("max_length_meters") or 60 self.max_duration = kwargs.get("max_duration") or timedelta(seconds=10) - def add_synthetic_event(self, dataset: EventDataset): + def add_synthetic_event(self, dataset: EventDataset) -> EventDataset: pitch = dataset.metadata.pitch_dimensions valid_event_types = [ @@ -130,3 +130,4 @@ def add_synthetic_event(self, dataset: EventDataset): **carry_event_args, **generic_event_args ) dataset.records.insert(idx + idx_plus, new_carry) + return dataset diff --git a/kloppy/tests/test_synthetic_event_generator.py b/kloppy/tests/test_synthetic_event_generator.py index d1c1d26cd..e3310fd55 100644 --- a/kloppy/tests/test_synthetic_event_generator.py +++ b/kloppy/tests/test_synthetic_event_generator.py @@ -110,7 +110,7 @@ def test_synthetic_carry_generator(self, base_dir): ) with performance_logging("generating synthetic events"): - dataset.add_synthetic_event( + dataset = dataset.add_synthetic_event( EventType.CARRY, min_length_meters=min_length_meters, max_length_meters=max_length_meters, From e60667c5c31afbccc0a035db55c607dc7b030886 Mon Sep 17 00:00:00 2001 From: lodevt Date: Fri, 17 Jan 2025 14:10:24 +0100 Subject: [PATCH 21/22] some refactoring --- .../synthetic_event_generators/carry.py | 189 +++++++++--------- 1 file changed, 97 insertions(+), 92 deletions(-) diff --git a/kloppy/domain/services/synthetic_event_generators/carry.py b/kloppy/domain/services/synthetic_event_generators/carry.py index 0fe0c18a8..b4fd4424c 100644 --- a/kloppy/domain/services/synthetic_event_generators/carry.py +++ b/kloppy/domain/services/synthetic_event_generators/carry.py @@ -1,19 +1,40 @@ -import uuid from datetime import timedelta from typing import Optional from kloppy.domain import ( EventDataset, - EventType, BodyPart, CarryResult, Unit, EventFactory, + PassEvent, + ShotEvent, + TakeOnEvent, + ClearanceEvent, + InterceptionEvent, + DuelEvent, + RecoveryEvent, + MiscontrolEvent, + GoalkeeperEvent, + GenericEvent, ) +from kloppy.domain.models.event import PressureEvent from kloppy.domain.services.synthetic_event_generators.synthetic_event_generator import ( SyntheticEventGenerator, ) +VALID_EVENT = ( + PassEvent, + ShotEvent, + TakeOnEvent, + ClearanceEvent, + InterceptionEvent, + DuelEvent, + RecoveryEvent, + MiscontrolEvent, + GoalkeeperEvent, +) + class SyntheticCarryGenerator(SyntheticEventGenerator): def __init__(self, event_factory: Optional[EventFactory] = None, **kwargs): @@ -25,109 +46,93 @@ def __init__(self, event_factory: Optional[EventFactory] = None, **kwargs): def add_synthetic_event(self, dataset: EventDataset) -> EventDataset: pitch = dataset.metadata.pitch_dimensions - valid_event_types = [ - EventType.PASS, - EventType.SHOT, - EventType.TAKE_ON, - EventType.CLEARANCE, - EventType.INTERCEPTION, - EventType.DUEL, - EventType.RECOVERY, - EventType.MISCONTROL, - EventType.GOALKEEPER, - ] - for idx, event in enumerate(dataset.events): - if event.event_type not in valid_event_types: + if not isinstance(event, VALID_EVENT): continue idx_plus = 1 - go_to_next_event = True - while idx + idx_plus < len(dataset.events) and go_to_next_event: + next_event = None + while idx + idx_plus < len(dataset.events): next_event = dataset.events[idx + idx_plus] - if next_event.event_type in [ - EventType.GENERIC, - EventType.PRESSURE, - ]: + if isinstance(next_event, (GenericEvent, PressureEvent)): idx_plus += 1 continue else: - go_to_next_event = False - if not event.team.team_id == next_event.team.team_id: - continue - - if next_event.event_type not in valid_event_types: - continue + break - if hasattr(event, "end_coordinates"): - last_coord = event.end_coordinates - elif hasattr(event, "receiver_coordinates"): - last_coord = event.receiver_coordinates - else: - last_coord = event.coordinates + if not isinstance(next_event, VALID_EVENT): + continue + if not event.team.team_id == next_event.team.team_id: + continue + if hasattr(event, "end_coordinates"): + last_coord = event.end_coordinates + elif hasattr(event, "receiver_coordinates"): + last_coord = event.receiver_coordinates + else: + last_coord = event.coordinates - new_coord = next_event.coordinates + new_coord = next_event.coordinates - distance_meters = pitch.distance_between( - new_coord, last_coord, Unit.METERS - ) - # Not far enough - if distance_meters < self.min_length_meters: - continue - # Too far - if distance_meters > self.max_length_meters: - continue + distance_meters = pitch.distance_between( + new_coord, last_coord, Unit.METERS + ) + # Not far enough + if distance_meters < self.min_length_meters: + continue + # Too far + if distance_meters > self.max_length_meters: + continue - dt = next_event.timestamp - event.timestamp - # not same phase - if dt > self.max_duration: - continue - # not same period - if not event.period.id == next_event.period.id: - continue + dt = next_event.timestamp - event.timestamp + # not same phase + if dt > self.max_duration: + continue + # not same period + if not event.period.id == next_event.period.id: + continue - # not headed shot - if next_event.event_type == EventType.SHOT and any( - qualifier.name == "body_part" - and qualifier.value in [BodyPart.HEAD, BodyPart.HEAD_OTHER] - for qualifier in next_event.qualifiers or [] - ): - continue + # not headed shot + if isinstance(next_event, ShotEvent) and any( + qualifier.name == "body_part" + and qualifier.value in [BodyPart.HEAD, BodyPart.HEAD_OTHER] + for qualifier in next_event.qualifiers or [] + ): + continue - if ( - hasattr(event, "end_timestamp") - and event.end_timestamp is not None - ): - last_timestamp = event.end_timestamp - elif ( - hasattr(event, "receive_timestamp") - and event.receive_timestamp is not None - ): - last_timestamp = event.receive_timestamp - else: - last_timestamp = ( - event.timestamp - + (next_event.timestamp - event.timestamp) / 10 - ) - generic_event_args = { - "event_id": f"carry-{event.event_id}", - "coordinates": last_coord, - "team": next_event.team, - "player": next_event.player, - "ball_owning_team": next_event.ball_owning_team, - "ball_state": event.ball_state, - "period": next_event.period, - "timestamp": last_timestamp, - "raw_event": None, - } - carry_event_args = { - "result": CarryResult.COMPLETE, - "qualifiers": None, - "end_coordinates": new_coord, - "end_timestamp": next_event.timestamp, - } - new_carry = self.event_factory.build_carry( - **carry_event_args, **generic_event_args + if ( + hasattr(event, "end_timestamp") + and event.end_timestamp is not None + ): + last_timestamp = event.end_timestamp + elif ( + hasattr(event, "receive_timestamp") + and event.receive_timestamp is not None + ): + last_timestamp = event.receive_timestamp + else: + last_timestamp = ( + event.timestamp + + (next_event.timestamp - event.timestamp) / 10 ) - dataset.records.insert(idx + idx_plus, new_carry) + generic_event_args = { + "event_id": f"carry-{event.event_id}", + "coordinates": last_coord, + "team": next_event.team, + "player": next_event.player, + "ball_owning_team": next_event.ball_owning_team, + "ball_state": event.ball_state, + "period": next_event.period, + "timestamp": last_timestamp, + "raw_event": None, + } + carry_event_args = { + "result": CarryResult.COMPLETE, + "qualifiers": None, + "end_coordinates": new_coord, + "end_timestamp": next_event.timestamp, + } + new_carry = self.event_factory.build_carry( + **carry_event_args, **generic_event_args + ) + dataset.records.insert(idx + idx_plus, new_carry) return dataset From 66193b03b2282732c3f3a4f059cbcbb5542f788d Mon Sep 17 00:00:00 2001 From: lodevt Date: Mon, 20 Jan 2025 10:23:42 +0100 Subject: [PATCH 22/22] no carries before set piece --- kloppy/domain/services/synthetic_event_generators/carry.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/kloppy/domain/services/synthetic_event_generators/carry.py b/kloppy/domain/services/synthetic_event_generators/carry.py index b4fd4424c..a57b3fa9e 100644 --- a/kloppy/domain/services/synthetic_event_generators/carry.py +++ b/kloppy/domain/services/synthetic_event_generators/carry.py @@ -18,7 +18,7 @@ GoalkeeperEvent, GenericEvent, ) -from kloppy.domain.models.event import PressureEvent +from kloppy.domain.models.event import PressureEvent, SetPieceQualifier from kloppy.domain.services.synthetic_event_generators.synthetic_event_generator import ( SyntheticEventGenerator, ) @@ -64,6 +64,8 @@ def add_synthetic_event(self, dataset: EventDataset) -> EventDataset: continue if not event.team.team_id == next_event.team.team_id: continue + if next_event.get_qualifier_value(SetPieceQualifier) is not None: + continue if hasattr(event, "end_coordinates"): last_coord = event.end_coordinates elif hasattr(event, "receiver_coordinates"):