Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions grain/_src/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ py_test(
":operations",
":options",
":samplers",
"//grain",
"//grain/_src/core:sharding",
"//grain/_src/core:transforms",
"//grain/_src/python/dataset",
Expand Down
10 changes: 8 additions & 2 deletions grain/_src/python/data_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@

from collections.abc import Sequence
import functools
from multiprocessing import shared_memory
import pathlib
import platform
import sys
import threading
from typing import Any, Union
from unittest import mock

Expand All @@ -30,17 +32,17 @@
from grain._src.python import data_loader as data_loader_lib
from grain._src.python import options
from grain._src.python import samplers
# pylint: disable=g-importing-member
from grain._src.python.data_sources import ArrayRecordDataSource
from grain._src.python.data_sources import RangeDataSource
from grain._src.python.data_sources import SharedMemoryDataSource
from grain._src.python.dataset.transformations import batch
from grain._src.python.dataset.transformations import process_prefetch
from grain._src.python.dataset.transformations import source
from grain._src.python.ipc import shared_memory_array
from grain._src.python.operations import BatchOperation
from grain._src.python.operations import FilterOperation
from grain._src.python.operations import MapOperation
from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint
# pylint: enable=g-importing-member
import numpy as np
import parameterized

Expand Down Expand Up @@ -186,6 +188,10 @@ def test_copy_skipped_flags_c_contiguous(self):
{"num_threads_per_worker": 15},
])
class DataLoaderTest(absl_parameterized.TestCase):

def tearDown(self):
super().tearDown()

# Number of prefetch threads for each Grain worker
num_threads_per_worker: int | None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

from absl import flags
import cloudpickle
from grain._src.core import monitoring as grain_monitoring
from grain._src.core.config import config
import multiprocessing as mp
from grain._src.python import grain_logging
Expand Down Expand Up @@ -379,7 +378,7 @@ def _stats(self):
)

# pytype: enable=attribute-error
# pylint: enable=protected-access
# pylint: disable=protected-access

def start_prefetch(self) -> None:
"""Starts prefetching elements in background.
Expand Down
Loading