forked from flatland-association/flatland-baselines
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_solution.py
More file actions
67 lines (54 loc) · 2.61 KB
/
run_solution.py
File metadata and controls
67 lines (54 loc) · 2.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from flatland.envs.rail_env import RailEnv
from flatland.evaluators.client import FlatlandRemoteClient
from flatland_baselines.deadlock_avoidance_heuristic.observation.full_env_observation import FullEnvObservation
from flatland_baselines.deadlock_avoidance_heuristic.policy.deadlock_avoidance_policy import DeadLockAvoidancePolicy
from flatland_baselines.deadlock_avoidance_heuristic.utils.progress_bar import ProgressBar
# Single shared client for the Flatland remote evaluation service; used by
# main() to create episodes, step them, and submit the final results.
remote_client = FlatlandRemoteClient()
def main(debug: bool = True):
    """Run the DeadLockAvoidancePolicy against the Flatland remote evaluator.

    Repeatedly asks the remote service for a new evaluation episode, steps
    through it with the deadlock-avoidance heuristic until the episode is
    done, and stops once the service signals that all episodes have been
    evaluated. Finally submits the results.

    Parameters
    ----------
    debug : bool, optional
        When True, print a progress summary once an episode finishes.
    """
    episode = 0
    while True:
        print("/ start DeadLockAvoidancePolicy", flush=True)
        episode += 1
        print("[INFO] EPISODE_START : {}".format(episode))
        # NO WAY TO CHECK service/self.evaluation_done in client
        # ------------------- user code -------------------------
        my_observation_builder = FullEnvObservation()
        # ---------------------------------------------------------
        observations, info = remote_client.env_create(obs_builder_object=my_observation_builder)
        # ------------------- user code -------------------------
        env: RailEnv = remote_client.env
        flatlandSolver = DeadLockAvoidancePolicy()
        # ---------------------------------------------------------
        if isinstance(observations, bool):
            if not observations:
                # The remote env returns False as the first obs when it is
                # done evaluating all the individual episodes.
                print("[INFO] DONE ALL, BREAKING")
                break
        pbar = ProgressBar()
        total_reward = 0
        nbr_done = 0
        while True:
            try:
                actions = flatlandSolver.act_many(env.get_agent_handles(), observations)
                observations, all_rewards, done, info = remote_client.env_step(actions)
                total_reward += sum(all_rewards.values())
                if env._elapsed_steps < env._max_episode_steps:
                    # Count per-agent done flags, excluding the aggregate
                    # '__all__' entry by key rather than assuming it is the
                    # last item in the dict's insertion order (bug fix).
                    nbr_done = sum(v for k, v in done.items() if k != '__all__')
            except Exception:
                print("[ERR] DONE BUT step() CALLED")
                raise  # re-raise with the original traceback intact
            if done['__all__']:
                if debug:
                    pbar.console_print(nbr_done, env.get_num_agents(), 'Nbr of done agents: {}'.format(len(done) - 1), '')
                print("[INFO] TOTAL_REW: ", total_reward)
                print("[INFO] EPISODE_DONE : ", episode)
                break
    print("Evaluation Complete...")
    print(remote_client.submit())
    print("\\ end random_agent", flush=True)
# Script entry point: run the remote-evaluation loop when executed directly.
if __name__ == '__main__':
    main()