Skip to content

Commit

Permalink
fix sampling error
Browse files Browse the repository at this point in the history
  • Loading branch information
XuhuiZhou committed Oct 4, 2024
1 parent db8839e commit 0fcfcb0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
10 changes: 8 additions & 2 deletions sotopia/samplers/constraint_based_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,16 @@ def sample(
agents_which_fit_scenario: list[list[str]] = []

if self.env_candidates is None:
self.env_candidates = EnvironmentProfile.all()
env_candidates = EnvironmentProfile.all()
if not env_candidates:
raise ValueError("No environment candidates available for sampling.")
self.env_candidates = env_candidates

if self.agent_candidates is None:
self.agent_candidates = AgentProfile.all()
agent_candidates = AgentProfile.all()
if not agent_candidates:
raise ValueError("No agent candidates available for sampling.")
self.agent_candidates = agent_candidates

agent_candidate_ids: set[str] | None = None
if self.agent_candidates:
Expand Down
10 changes: 8 additions & 2 deletions sotopia/samplers/uniform_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,16 @@ def sample(
assert replacement, "Uniform sampling without replacement is not supported yet"

if self.env_candidates is None:
self.env_candidates = EnvironmentProfile.all()
env_candidates = EnvironmentProfile.all()
if not env_candidates:
raise ValueError("No environment candidates available for sampling.")
self.env_candidates = env_candidates

if self.agent_candidates is None:
self.agent_candidates = AgentProfile.all()
agent_candidates = AgentProfile.all()
if not agent_candidates:
raise ValueError("No agent candidates available for sampling.")
self.agent_candidates = agent_candidates

for _ in range(size):
env_profile = random.choice(self.env_candidates)
Expand Down

0 comments on commit 0fcfcb0

Please sign in to comment.