Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions kloppy/domain/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,6 +1186,42 @@ def aggregate(self, type_: str, **aggregator_kwargs) -> List[Any]:

return aggregator.aggregate(self)

def add_synthetic_event(
self, event_type_: EventType, event_factory_=None, **kwargs
):
"""
Adds synthetic events of the specified type to the event dataset. This method analyzes the stream of
events and inserts synthetic events at the appropriate points within the dataset based on the event type.

Args:
event_type_ (EventType): The type of event to generate. The supported event types are currently:
- `EventType.CARRY`: Generates carry events.
event_factory_ (Optional[EventFactory]): An optional event factory to create the events. If not provided,
a default event factory will be used.
**kwargs: Additional configuration parameters passed to the specific synthetic event generator class.
The expected parameters depend on the type of event being generated (e.g., `SyntheticCarryGenerator`).

Raises:
KloppyError: If the provided `event_type_` is not supported or invalid.

Example:
To generate a synthetic carry event:
add_synthetic_event(EventType.CARRY, event_factory=my_event_factory, min_length_meters=3, max_length_meters=60, max_duration=timedelta(seconds=10))
"""
if event_type_ == EventType.CARRY:
from kloppy.domain.services.synthetic_event_generators.carry import (
SyntheticCarryGenerator,
)

synthetic_event_generator = SyntheticCarryGenerator(
event_factory_, **kwargs
)
else:
raise KloppyError(
f"Not possible to generate synthetic {event_type_}"
)
return synthetic_event_generator.add_synthetic_event(self)


__all__ = [
"EnumQualifier",
Expand Down
Empty file.
140 changes: 140 additions & 0 deletions kloppy/domain/services/synthetic_event_generators/carry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from datetime import timedelta
from typing import Optional

from kloppy.domain import (
EventDataset,
BodyPart,
CarryResult,
Unit,
EventFactory,
PassEvent,
ShotEvent,
TakeOnEvent,
ClearanceEvent,
InterceptionEvent,
DuelEvent,
RecoveryEvent,
MiscontrolEvent,
GoalkeeperEvent,
GenericEvent,
)
from kloppy.domain.models.event import PressureEvent, SetPieceQualifier
from kloppy.domain.services.synthetic_event_generators.synthetic_event_generator import (
SyntheticEventGenerator,
)

VALID_EVENT = (
PassEvent,
ShotEvent,
TakeOnEvent,
ClearanceEvent,
InterceptionEvent,
DuelEvent,
RecoveryEvent,
MiscontrolEvent,
GoalkeeperEvent,
)


class SyntheticCarryGenerator(SyntheticEventGenerator):
def __init__(self, event_factory: Optional[EventFactory] = None, **kwargs):
self.event_factory = event_factory or EventFactory()
self.min_length_meters = kwargs.get("min_length_meters") or 3
self.max_length_meters = kwargs.get("max_length_meters") or 60
self.max_duration = kwargs.get("max_duration") or timedelta(seconds=10)

def add_synthetic_event(self, dataset: EventDataset) -> EventDataset:
pitch = dataset.metadata.pitch_dimensions

for idx, event in enumerate(dataset.events):
if not isinstance(event, VALID_EVENT):
continue
idx_plus = 1
next_event = None
while idx + idx_plus < len(dataset.events):
next_event = dataset.events[idx + idx_plus]

if isinstance(next_event, (GenericEvent, PressureEvent)):
idx_plus += 1
continue
else:
break

if not isinstance(next_event, VALID_EVENT):
continue
if not event.team.team_id == next_event.team.team_id:
continue
if next_event.get_qualifier_value(SetPieceQualifier) is not None:
continue
if hasattr(event, "end_coordinates"):
last_coord = event.end_coordinates
elif hasattr(event, "receiver_coordinates"):
last_coord = event.receiver_coordinates
else:
last_coord = event.coordinates

new_coord = next_event.coordinates

distance_meters = pitch.distance_between(
new_coord, last_coord, Unit.METERS
)
# Not far enough
if distance_meters < self.min_length_meters:
continue
# Too far
if distance_meters > self.max_length_meters:
continue

dt = next_event.timestamp - event.timestamp
# not same phase
if dt > self.max_duration:
continue
# not same period
if not event.period.id == next_event.period.id:
continue

# not headed shot
if isinstance(next_event, ShotEvent) and any(
qualifier.name == "body_part"
and qualifier.value in [BodyPart.HEAD, BodyPart.HEAD_OTHER]
for qualifier in next_event.qualifiers or []
):
continue

if (
hasattr(event, "end_timestamp")
and event.end_timestamp is not None
):
last_timestamp = event.end_timestamp
elif (
hasattr(event, "receive_timestamp")
and event.receive_timestamp is not None
):
last_timestamp = event.receive_timestamp
else:
last_timestamp = (
event.timestamp
+ (next_event.timestamp - event.timestamp) / 10
)
generic_event_args = {
"event_id": f"carry-{event.event_id}",
"coordinates": last_coord,
"team": next_event.team,
"player": next_event.player,
"ball_owning_team": next_event.ball_owning_team,
"ball_state": event.ball_state,
"period": next_event.period,
"timestamp": last_timestamp,
"raw_event": None,
}
carry_event_args = {
"result": CarryResult.COMPLETE,
"qualifiers": None,
"end_coordinates": new_coord,
"end_timestamp": next_event.timestamp,
}
new_carry = self.event_factory.build_carry(
**carry_event_args, **generic_event_args
)
dataset.records.insert(idx + idx_plus, new_carry)
return dataset
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod

from kloppy.domain import EventDataset


class SyntheticEventGenerator(ABC):
@abstractmethod
def add_synthetic_event(self, dataset: EventDataset) -> EventDataset:
raise NotImplementedError
127 changes: 127 additions & 0 deletions kloppy/tests/test_synthetic_event_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from datetime import timedelta

from kloppy.domain import (
EventType,
Unit,
)
from kloppy.utils import performance_logging
from kloppy import statsbomb, statsperform


class TestSyntheticEventGenerator:
""""""

def _load_dataset_statsperform(
self, base_dir, base_filename="statsperform"
):
return statsperform.load_event(
ma1_data=base_dir / f"files/{base_filename}_event_ma1.json",
ma3_data=base_dir / f"files/{base_filename}_event_ma3.json",
)

def _load_dataset_statsbomb(
self, base_dir, base_filename="statsbomb", event_types=None
):
return statsbomb.load(
event_data=base_dir / f"files/{base_filename}_event.json",
lineup_data=base_dir / f"files/{base_filename}_lineup.json",
event_types=event_types,
)

def calculate_carry_accuracy(
self, real_carries, generated_carries, real_carries_with_min_length
):
def is_match(real_carry, generated_carry):
return (
real_carry.player
and generated_carry.player
and real_carry.player.player_id
== generated_carry.player.player_id
and real_carry.period == generated_carry.period
and abs(real_carry.timestamp - generated_carry.timestamp)
< timedelta(seconds=5)
)

true_positives = 0
matched_real_carries = set()
for generated_carry in generated_carries:
for idx, real_carry in enumerate(real_carries):
if idx in matched_real_carries:
continue
if is_match(real_carry, generated_carry):
true_positives += 1
matched_real_carries.add(idx)
break

false_negatives = 0
matched_generated_carries = set()
for real_carry in real_carries_with_min_length:
found_match = False
for idx, generated_carry in enumerate(generated_carries):
if idx in matched_generated_carries:
continue
if is_match(real_carry, generated_carry):
found_match = True
matched_generated_carries.add(idx)
break
if not found_match:
false_negatives += 1

false_positives = len(generated_carries) - true_positives

accuracy = true_positives / (
true_positives + false_positives + false_negatives
)

print("TP:", true_positives)
print("FP:", false_positives)
print("FN:", false_negatives)
print("accuracy:", accuracy)

return accuracy

def test_synthetic_carry_generator(self, base_dir):
dataset_with_carries = self._load_dataset_statsbomb(base_dir)
pitch = dataset_with_carries.metadata.pitch_dimensions

min_length_meters = 3
max_length_meters = 60
max_duration = timedelta(seconds=10)

all_statsbomb_caries = dataset_with_carries.find_all("carry")
all_qualifying_statsbomb_queries = [
carry
for carry in all_statsbomb_caries
if (
min_length_meters
<= pitch.distance_between(
carry.coordinates, carry.end_coordinates, Unit.METERS
)
<= max_length_meters
and carry.end_timestamp - carry.timestamp < max_duration
)
]

dataset = self._load_dataset_statsbomb(
base_dir,
event_types=[
event.value for event in EventType if event.value != "CARRY"
],
)

with performance_logging("generating synthetic events"):
dataset = dataset.add_synthetic_event(
EventType.CARRY,
min_length_meters=min_length_meters,
max_length_meters=max_length_meters,
max_duration=max_duration,
)
all_carries = dataset.find_all("carry")
assert (
self.calculate_carry_accuracy(
all_statsbomb_caries,
all_carries,
all_qualifying_statsbomb_queries,
)
> 0.80
)