-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerator.py
More file actions
24 lines (19 loc) · 887 Bytes
/
generator.py
File metadata and controls
24 lines (19 loc) · 887 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import pickle as pkl
import tensorflow as tf
import numpy as np
import os
class generator():
    """Stream pickled samples for one dataset split, optionally as a ``tf.data.Dataset``.

    Samples are expected at ``<data_path>/<split>/g_<i>.pkl`` for
    ``i = 0 .. max_idx - 1``; each file holds a 6-item sequence
    ``(X, y, y_mask, adj, adj_mean, adj_norm)``.
    """

    def __init__(self, split, data_path):
        """Record the split location and count its sample files.

        Args:
            split: Name of the sub-directory of ``data_path`` holding the split.
            data_path: Root directory (str or path-like) containing the splits.
        """
        self.split = split
        self.data_path = data_path
        split_dir = os.path.join(self.data_path, self.split)
        # Count only ``g_<digits>.pkl`` entries: stray files (``.DS_Store``,
        # notes, etc.) would otherwise inflate max_idx and make
        # data_generator() try to open a non-existent ``g_<max_idx-1>.pkl``.
        # NOTE: sample indices are still assumed to be contiguous from 0.
        self.max_idx = sum(
            1
            for name in os.listdir(split_dir)
            if name.startswith('g_')
            and name.endswith('.pkl')
            and name[2:-4].isdecimal()
        )

    def data_generator(self):
        """Yield every sample of the split in index order.

        Yields:
            The tuple ``(X, y, y_mask, adj, adj_mean, adj_norm)`` unpickled
            from ``g_<i>.pkl`` for each ``i`` in ``range(self.max_idx)``.

        Raises:
            FileNotFoundError: if an index in that range has no file.
            ValueError: if a file unpickles to a sequence whose length is not 6.
        """
        split_dir = os.path.join(self.data_path, self.split)
        for i in range(self.max_idx):
            # SECURITY: pickle.load can execute arbitrary code; only read
            # sample files that come from a trusted source.
            with open(os.path.join(split_dir, f'g_{i}.pkl'), 'rb') as file:
                X, y, y_mask, adj, adj_mean, adj_norm = pkl.load(file)
            # Yield after the ``with`` so the file is closed before control
            # passes to the consumer (and is not leaked if iteration stops).
            yield X, y, y_mask, adj, adj_mean, adj_norm

    def tf_generator(self, generator=None):
        """Wrap a sample generator in a prefetching ``tf.data.Dataset``.

        Args:
            generator: Zero-argument callable returning an iterator of
                6-tuples. Defaults to ``self.data_generator``.

        Returns:
            A dataset of at most ``self.max_idx`` elements, each a 6-tuple of
            ``tf.float32`` tensors, with prefetching enabled.
        """
        if generator is None:
            generator = self.data_generator
        # NOTE(review): ``output_types`` is deprecated since TF 2.4 in favour of
        # ``output_signature``; kept because this file also uses the older
        # ``tf.data.experimental.AUTOTUNE`` spelling — confirm the TF version.
        tf_gen = tf.data.Dataset.from_generator(
            generator,
            output_types=(tf.float32,) * 6,
        )
        tf_gen = tf_gen.take(self.max_idx)
        # Prefetch as the final stage so the whole pipeline overlaps with
        # the consumer.
        tf_gen = tf_gen.prefetch(tf.data.experimental.AUTOTUNE)
        return tf_gen