
Commit ef9a7bf

llava_agent: Add frameskipping and cleanup
1 parent: 2442b9d · commit: ef9a7bf

File tree

1 file changed: +22 / -11 lines

llava_agent.py

Lines changed: 22 additions & 11 deletions
@@ -17,10 +17,12 @@ def parse_args():
 
     parser.add_argument("--frames", type=str, default="frames")
     parser.add_argument("--train-mins", type=int, default=60)
-    parser.add_argument("--log", type=str, default="llm_agent_"+str(uuid4())+".csv")
+    parser.add_argument("--log", type=str,
+                        default="llm_agent_"+str(uuid4())+".csv")
 
     return parser.parse_args()
 
+
 def obs_to_bytes(observation):
     """Converts an observation encoded as a numpy array into a bytes representation of a PNG image."""
     image = Image.fromarray(observation)
@@ -35,23 +37,25 @@ def obs_to_bytes(observation):
 args = parse_args()
 
 config = dict(
-    max_block_generate_distance=3, # 16x3 blocks
+    max_block_generate_distance=3,  # 16x3 blocks
     # hud_scaling=0.9,
     fov=90,
     console_alpha=0,
     smooth_lighting=False,
     performance_tradeoffs=True,
     enable_particles=False,
 )
-
 env = gym.make(
     "Craftium/OpenWorld-v0",
-    frameskip=1,
+    frameskip=3,
     obs_width=520,
     obs_height=520,
     # render_mode="human",
     # pipe_proc=False,
     minetest_conf=config,
+    sync_mode=True,
+    # max_fps=60,
+    pmul=20,
 )
 
 observation, info = env.reset()
@@ -62,7 +66,8 @@ def obs_to_bytes(observation):
              "select hotbar slot 1", "select hotbar slot 2", "select hotbar slot 3", "select hotbar slot 4", "select hotbar slot 5",
              "move camera right", "move camera left", "move camera up", "move camera down"]
 
-objectives = ["is to chop a tree", "is to collect stone", "is to collect iron", "is to find diamond blocks"]
+objectives = ["is to chop a tree", "is to collect stone",
+              "is to collect iron", "is to find diamond blocks"]
 obj_rwds = [128, 256, 1024, 2048]
 objective_id = 0
 
@@ -72,7 +77,8 @@ def obs_to_bytes(observation):
 episode = 0
 while (time.time() - start) / 60 < args.train_mins:
     img_bytes, img = obs_to_bytes(observation)
-    img.save(os.path.join(args.frames, f"frame_{str(t_step).zfill(7)}.png"), "PNG")
+    img.save(os.path.join(
+        args.frames, f"frame_{str(t_step).zfill(7)}.png"), "PNG")
 
     prompt = f"You are a reinforcement learning agent in the Minecraft game. You will be presented the current observation, and you have to select the next action with the ultimate objective to fulfill your goal. In this case, the goal {objectives[objective_id]}. You should fight monsters and hunt animals just as a secondary objective and survival. Available actions are: do nothing, move forward, move backward, move left, move right, jump, sneak, use tool, place, select hotbar slot 1, select hotbar slot 2, select hotbar slot 3, select hotbar slot 4, select hotbar slot 5, move camera right, move camera left, move camera up, move camera down. From now on, your responses must only contain the name of the action you will take, nothing else."
     print("Prompt:", prompt)
@@ -86,7 +92,7 @@ def obs_to_bytes(observation):
     incorrect = False
     candidates = [i for i, name in enumerate(act_names) if name in act_str]
     print(candidates)
-    if len(candidates) == 0: # if the response is in an incorrect format
+    if len(candidates) == 0:  # if the response is in an incorrect format
         action = env.action_space.sample() # take a random action
         incorrect = True
         print("[WARNING] Incorrect action. Using random action.")
@@ -98,13 +104,16 @@ def obs_to_bytes(observation):
     print(f"* Action: {action}")
 
     if act_names[action] == "jump":
+        # jump forward
         _, _, _, _, _ = env.step(action)
         observation, reward, terminated, truncated, _info = env.step(1)
     else:
-        observation, reward, terminated, truncated, _info = env.step(action)
+        observation, reward, terminated, truncated, _info = env.step(
+            action)
 
     ep_ret += reward
-    print(f"Step: {t_step}, Elapsed: {int(time.time()-start)}s, Reward: {reward}, Ep. ret.: {ep_ret}")
+    print(
+        f"Step: {t_step}, Elapsed: {int(time.time()-start)}s, Reward: {reward}, Ep. ret.: {ep_ret}")
 
     # check if a stage has been completed
     if reward >= 128.0:
@@ -116,8 +125,10 @@ def obs_to_bytes(observation):
 
     with open(args.log, "a" if t_step > 0 else "w") as f:
         if t_step == 0:
-            f.write("t_step,episode,elapsed mins,reward,ep_ret,objective_id,id\n")
-        f.write(f"{t_step},{episode},{(time.time()-start)/60},{reward},{ep_ret},{objective_id},{args.log}\n")
+            f.write(
+                "t_step,episode,elapsed mins,reward,ep_ret,objective_id,id\n")
+        f.write(
+            f"{t_step},{episode},{(time.time()-start)/60},{reward},{ep_ret},{objective_id},{args.log}\n")
 
     if terminated or truncated:
         episode += 1
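
Note on the frameskip=1 → frameskip=3 change above: frameskip conventionally means that each selected action is repeated for that many simulation frames and the per-frame rewards are summed, so the costly LLaVA query runs once every few frames of game time instead of on every frame. The sketch below is only an illustration of that conventional behaviour as a generic Gymnasium wrapper, assuming Craftium's built-in frameskip argument works along the same lines; the FrameSkip class and skip parameter are names made up for the example, not part of Craftium's API.

import gymnasium as gym


class FrameSkip(gym.Wrapper):
    """Repeat each chosen action for `skip` consecutive frames, summing rewards.

    Illustrative only: Craftium's frameskip argument is assumed to behave
    roughly like this, but its exact semantics may differ.
    """

    def __init__(self, env, skip=3):
        super().__init__(env)
        self.skip = skip

    def step(self, action):
        total_reward = 0.0
        observation, info = None, {}
        terminated = truncated = False
        for _ in range(self.skip):
            observation, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            if terminated or truncated:
                break  # stop repeating once the episode ends
        return observation, total_reward, terminated, truncated, info

Under these semantics, with frameskip=3 each env.step() call in the agent loop advances the game about three frames, which is what makes the per-step LLM round trip affordable.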
