Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add batch environment #146

Draft
wants to merge 28 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e7f4a8a
add StatsBase package
Sid-Bhatia-0 Jun 15, 2021
fd659a7
add SingleRoomUndirectedBatch
Sid-Bhatia-0 Jun 15, 2021
0b8c1cb
add playability to SingleRoomUndirectedBatch
Sid-Bhatia-0 Jun 16, 2021
1bb2641
add keyword force to RLBase.reset! method
Sid-Bhatia-0 Jun 16, 2021
0d115af
fix replay method
Sid-Bhatia-0 Jun 16, 2021
d39fae7
write characters to terminal out while playing
Sid-Bhatia-0 Jun 16, 2021
20d103f
ignore scratchpad.jl
Sid-Bhatia-0 Jun 16, 2021
4c16b20
add tests for SingleRoomUndirectedBatch
Sid-Bhatia-0 Jun 16, 2021
342deb0
set state style to internal state, copy reward & done
Sid-Bhatia-0 Jun 17, 2021
4926c99
add benchmark_multi_threaded.jl
Sid-Bhatia-0 Jun 17, 2021
7824551
print NUM_ENVS
Sid-Bhatia-0 Jun 17, 2021
d711bcc
move num_envs to the last dimension
Sid-Bhatia-0 Jun 21, 2021
2d68a0a
update tests for batch envs
Sid-Bhatia-0 Jun 21, 2021
3dbd497
don't copy tile_map, reward, and done in RLBase API
Sid-Bhatia-0 Jun 21, 2021
8ec31a0
remove unnecessary RLBase.DefaultPlayer
Sid-Bhatia-0 Jun 21, 2021
73406ea
rename benchmark_multi_threaded.jl to benchmark_batch.jl
Sid-Bhatia-0 Jun 21, 2021
45bf86a
fix and cleanup benchmark_batch
Sid-Bhatia-0 Jun 21, 2021
ed2b37b
make move function type stable (huge improvement in performance)
Sid-Bhatia-0 Jun 21, 2021
2ceedad
add function sample_two_positions_without_replacement
Sid-Bhatia-0 Jun 21, 2021
2426955
add DataStructures package in benchmarking code
Sid-Bhatia-0 Jun 24, 2021
cda573e
add ACTION_NAMES in ModuleSingleRoomUndirectedBatch
Sid-Bhatia-0 Jun 24, 2021
26d1231
refactor benchmark_batch.jl
Sid-Bhatia-0 Jun 24, 2021
bb048cc
rename benchmark_batch.jl to benchmark_utils.jl
Sid-Bhatia-0 Jun 24, 2021
e8854e2
add SingleRoomUndirected
Sid-Bhatia-0 Jun 24, 2021
9f3b1bd
ignore generated benchmark files
Sid-Bhatia-0 Jun 24, 2021
b730958
add benchmarking for non-batch envs
Sid-Bhatia-0 Jun 24, 2021
e2f313f
make SingleRoomUndirected mutable and improve performance
Sid-Bhatia-0 Jun 24, 2021
6d114c6
remove constants NUM_RESETS, STEPS_PER_EPISODE, NUM_ENVS
Sid-Bhatia-0 Jun 24, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ Manifest.toml
# vim temporary files
*~
*.swp

/src/scratchpad.jl
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
Crayons = "4.0"
Expand Down
143 changes: 143 additions & 0 deletions benchmark/benchmark_multi_threaded.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import GridWorlds as GW
import ReinforcementLearningBase as RLBase
import BenchmarkTools as BT
import Dates

const STEPS_PER_RESET = 100
const NUM_RESETS = 100
const NUM_ENVS = 64

const information = Dict()

ENVS = [GW.ModuleSingleRoomUndirectedBatch.SingleRoomUndirectedBatch]

function run_random_policy!(env, num_resets, steps_per_reset)
num_envs = size(env.tile_map, 1)
action = Array{eltype(RLBase.action_space(env))}(undef, num_envs)
for _ in 1:num_resets
RLBase.reset!(env, force = true)
for _ in 1:steps_per_reset
state = RLBase.state(env)
for i in 1:num_envs
action[i] = rand(RLBase.action_space(env))
end
env(action)
is_terminated = RLBase.is_terminated(env)
reward = RLBase.reward(env)
end
end

return nothing
end

function format_benchmark(str::String)
l = split(str, "\n")
deleteat!(l, (1, 4, 9))
return strip.(l)
end

function write_benchmarks(information, file)
io = open(file, "w")

write(io, "Date: " * Dates.format(Dates.now(), "yyyy_mm_dd_HH_MM_SS") * "\n")
write(io, "# List of Environments\n")

for Env in ENVS
name = Env.body.body.body.name.name
write(io, " 1. [$(String(name))](#$(lowercase(String(name))))\n")
end

write(io, "\n")
write(io, "# Benchmarks\n\n")

for Env in ENVS
name = Env.body.body.body.name.name
env_benchmark = information[name]

write(io, "# $(String(name))\n\n")

write(io, "#### Run uniformly random policy, NUM_ENVS = $(NUM_ENVS), NUM_RESETS = $(NUM_RESETS), STEPS_PER_RESET = $(STEPS_PER_RESET), TOTAL_STEPS = $(NUM_RESETS * STEPS_PER_RESET)\n\n")
for line in format_benchmark(repr("text/plain", env_benchmark[:run_random_policy]))
write(io, line * "\n\n")
end

write(io, "#### $(String(Symbol(Env)))()\n\n")
for line in format_benchmark(repr("text/plain", env_benchmark[:instantiation]))
write(io, line * "\n\n")
end

write(io, "#### RLBase.reset!(env)\n\n")
for line in format_benchmark(repr("text/plain", env_benchmark[:reset!]))
write(io, line * "\n\n")
end

write(io, "#### RLBase.state(env)\n\n")
for line in format_benchmark(repr("text/plain", env_benchmark[:state]))
write(io, line * "\n\n")
end

write(io, "#### RLBase.action_space(env)\n\n")
for line in format_benchmark(repr("text/plain", env_benchmark[:action_space]))
write(io, line * "\n\n")
end

write(io, "#### RLBase.is_terminated(env)\n\n")
for line in format_benchmark(repr("text/plain", env_benchmark[:is_terminated]))
write(io, line * "\n\n")
end

write(io, "#### RLBase.reward(env)\n\n")
for line in format_benchmark(repr("text/plain", env_benchmark[:reward]))
write(io, line * "\n\n")
end

for action in keys(env_benchmark[:action_info])
write(io, "#### env($action)\n\n")
for line in format_benchmark(repr("text/plain", env_benchmark[:action_info][action]))
write(io, line * "\n\n")
end
end

end

close(io)
end

# compile everything once
for Env in ENVS
env = Env(num_envs = NUM_ENVS)
run_random_policy!(env, NUM_RESETS, STEPS_PER_RESET)
end

@info "First run (for compilation) is complete"

for Env in ENVS

env = Env(num_envs = NUM_ENVS)

env_benchmark = Dict()

env_benchmark[:run_random_policy] = BT.@benchmark run_random_policy!($(Ref(env))[], $(Ref(NUM_RESETS))[], $(Ref(STEPS_PER_RESET))[])

env_benchmark[:instantiation] = BT.@benchmark $(Ref(Env))[](num_envs = $(NUM_ENVS)[])

env_benchmark[:reset!] = BT.@benchmark RLBase.reset!($(Ref(env))[], force = true)
env_benchmark[:state] = BT.@benchmark RLBase.state($(Ref(env))[])
env_benchmark[:action_space] = BT.@benchmark RLBase.action_space($(Ref(env))[])
env_benchmark[:is_terminated] = BT.@benchmark RLBase.is_terminated($(Ref(env))[])
env_benchmark[:reward] = BT.@benchmark RLBase.reward($(Ref(env))[])

action_info = Dict()
for action in RLBase.action_space(env)
actions = fill(action, NUM_ENVS)
action_info[Symbol(action)] = BT.@benchmark $(Ref(env))[]($(Ref(actions))[])
end
env_benchmark[:action_info] = action_info

name = Env.body.body.body.name.name
information[name] = env_benchmark

@info "$(name) benchmark complete"
end

write_benchmarks(information, "benchmark_multi_threaded.md")
1 change: 1 addition & 0 deletions src/GridWorlds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ include("actions.jl")
include("objects.jl")
include("grid_world_base.jl")
include("abstract_grid_world.jl")
include("play.jl")
include("envs/envs.jl")
include("textual_rendering.jl")

Expand Down
1 change: 1 addition & 0 deletions src/envs/envs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ include("snake.jl")
include("catcher.jl")
include("transport.jl")
include("collect_gems_undirected_multi_agent.jl")
include("single_room_undirected_batch.jl")
Loading