diff --git a/tools/analysis_3d/analysis_runner.py b/tools/analysis_3d/analysis_runner.py index d970d75d2..367f3728e 100644 --- a/tools/analysis_3d/analysis_runner.py +++ b/tools/analysis_3d/analysis_runner.py @@ -178,7 +178,9 @@ def _extra_scenario_data( t4 = Tier4(data_root=str(scene_root_dir_path), verbose=False) sample_data = self._extract_sample_data(t4=t4) - scenario_data[scene_token] = ScenarioData(scene_token=scene_token, sample_data=sample_data) + scenario_data[scene_token] = ScenarioData( + scene_token=scene_token, frames=len(t4.sample), sample_data=sample_data + ) return scenario_data def run(self) -> None: diff --git a/tools/analysis_3d/callbacks/category.py b/tools/analysis_3d/callbacks/category.py index abf5d4e2c..0ac1375ae 100644 --- a/tools/analysis_3d/callbacks/category.py +++ b/tools/analysis_3d/callbacks/category.py @@ -41,6 +41,7 @@ def __init__( def _visualize_total_category_counts( self, dataset_category_counts: Dict[str, Dict[str, int]], + dataset_frame_counts: Dict[str, int], split_name: str, log_scale: bool = True, figsize: tuple[int, int] = (15, 15), @@ -52,6 +53,8 @@ def _visualize_total_category_counts( :param log_scale: Set True to make the frequency in log-scale (power of 10). :param figsize: Figure size. """ + frames = sum(dataset_frame_counts.values()) + all_available_categories = [ category_name for category_counts in dataset_category_counts.values() @@ -82,7 +85,7 @@ def _visualize_total_category_counts( # Add some text for labels, title and custom x-axis tick labels, etc. ax.set_ylabel(self.y_axis_label) - ax.set_title(self.x_axis_label) + ax.set_title(f"{self.x_axis_label} (Total frames: {frames})") ax.set_yticks(y + height, all_available_categories) ax.legend(loc=self.legend_loc) ax.invert_yaxis() @@ -101,6 +104,7 @@ def run(self, dataset_split_analysis_data: Dict[DatasetSplitName, AnalysisData]) print_log(f"Running {self.__class__.__name__}") for split_option in SplitOptions: dataset_category_counts = {} + dataset_frame_counts = {} for dataset_split_name, analysis_data in dataset_split_analysis_data.items(): split_name = dataset_split_name.split_name if split_name != split_option.value: @@ -111,7 +115,11 @@ def run(self, dataset_split_analysis_data: Dict[DatasetSplitName, AnalysisData]) remapping_classes=self.remapping_classes ) + dataset_frame_counts[dataset_name] = analysis_data.frames + self._visualize_total_category_counts( - dataset_category_counts=dataset_category_counts, split_name=split_option.value + dataset_category_counts=dataset_category_counts, + dataset_frame_counts=dataset_frame_counts, + split_name=split_option.value, ) print_log(f"Done running {self.__class__.__name__}") diff --git a/tools/analysis_3d/callbacks/category_attribute.py b/tools/analysis_3d/callbacks/category_attribute.py index bbb1c05c2..581f27949 100644 --- a/tools/analysis_3d/callbacks/category_attribute.py +++ b/tools/analysis_3d/callbacks/category_attribute.py @@ -37,6 +37,7 @@ def run(self, dataset_split_analysis_data: Dict[DatasetSplitName, AnalysisData]) print_log(f"Running {self.__class__.__name__}") for split_option in SplitOptions: dataset_category_counts = {} + dataset_frame_counts = {} for dataset_split_name, analysis_data in dataset_split_analysis_data.items(): split_name = dataset_split_name.split_name if split_name != split_option: @@ -47,7 +48,11 @@ def run(self, dataset_split_analysis_data: Dict[DatasetSplitName, AnalysisData]) remapping_classes=self.remapping_classes, category_name=self.category_name ) + dataset_frame_counts[dataset_name] = analysis_data.frames + self._visualize_total_category_counts( - dataset_category_counts=dataset_category_counts, split_name=split_option.value + dataset_category_counts=dataset_category_counts, + dataset_frame_counts=dataset_frame_counts, + split_name=split_option.value, ) print_log(f"Done running {self.__class__.__name__}") diff --git a/tools/analysis_3d/data_classes.py b/tools/analysis_3d/data_classes.py index cb30e02ea..1385fed39 100644 --- a/tools/analysis_3d/data_classes.py +++ b/tools/analysis_3d/data_classes.py @@ -152,6 +152,7 @@ class ScenarioData: """Data class to save data for a scenario, for example, a list of SampleData.""" scene_token: str + frames: int sample_data: Dict[str, SampleData] = field(default_factory=lambda: {}) # Sample token, SampleAnalysis def add_sample_data(self, sample_data: SampleData) -> None: @@ -262,3 +263,11 @@ def aggregate_category_attr_counts( for name, counts in category_counts.items(): total_category_counts[name] += counts return total_category_counts + + @property + def frames(self) -> int: + """ + Get total frames in this AnalysisData. + :return: Total frames in this AnalysisData. + """ + return sum(scenario_data.frames for scenario_data in self.scenario_data.values())