Skip to content

Commit

Permalink
implemented labels condition for dictator game
Browse files Browse the repository at this point in the history
  • Loading branch information
phelps-sg committed Nov 23, 2023
1 parent 2545104 commit 17af3b3
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 19 deletions.
38 changes: 33 additions & 5 deletions llm_cooperation/experiments/dictator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
round_instructions,
run_and_record_experiment,
)
from llm_cooperation.experiments.dilemma import CONDITION_LABEL, Label
from llm_cooperation.gametypes.oneshot import OneShotResults, run_experiment

TOTAL_SHARE = 4
Expand Down Expand Up @@ -74,18 +75,41 @@ def project(color: str) -> str:
DictatorEnum.WHITE: "white",
}

reverse_color_mappings: Dict[str, DictatorEnum] = {
value: key for key, value in color_mappings.items()
numeral_mappings: Dict[DictatorEnum, str] = {
DictatorEnum.BLACK: "1",
DictatorEnum.BROWN: "2",
DictatorEnum.GREEN: "3",
DictatorEnum.BLUE: "4",
DictatorEnum.WHITE: "5",
}

number_mappings: Dict[DictatorEnum, str] = {
DictatorEnum.BLACK: "one",
DictatorEnum.BROWN: "two",
DictatorEnum.GREEN: "three",
DictatorEnum.BLUE: "four",
DictatorEnum.WHITE: "five",
}


def mappings_for(participant: Participant) -> Dict[DictatorEnum, str]:
label_type = participant[CONDITION_LABEL]
if label_type == Label.COLORS.value:
return color_mappings
elif label_type == Label.NUMBERS.value:
return number_mappings
elif label_type == Label.NUMERALS.value:
return numeral_mappings
raise ValueError(f"Unrecognized value {participant[CONDITION_LABEL]}")


@dataclass
class DictatorChoice:
value: DictatorEnum

# pylint: disable=unused-argument
def description(self, participant_condition: Participant) -> str:
return project(color_mappings[self.value])
return project(mappings_for(participant_condition)[self.value])

@property
def donation(self) -> float:
Expand All @@ -111,14 +135,18 @@ def payoff_allo(self) -> float:

DICTATOR_ATTRIBUTES: Grid = {
CONDITION_CHAIN_OF_THOUGHT: [True, False],
# CONDITION_LABEL: all_values(Label),
CONDITION_LABEL: all_values(Label),
CONDITION_CASE: all_values(Case),
CONDITION_PRONOUN: all_values(Pronoun),
CONDITION_DEFECT_FIRST: [True, False],
# CONDITION_LABELS_REVERSED: [True, False],
}


def inverted(mappings: Dict[DictatorEnum, str]) -> Dict[str, DictatorEnum]:
return {value: key for key, value in mappings.items()}


def payout_ego(choice: DictatorChoice) -> str:
return amount_as_str(choice.payoff_ego)

Expand Down Expand Up @@ -182,7 +210,7 @@ def extract_choice_dictator(
match = re.search(r"choice:\s*(.*)", text)
if match:
choice = match.group(1)
for key, value in reverse_color_mappings.items():
for key, value in inverted(mappings_for(participant)).items():
if key in choice:
return DictatorChoice(value)
raise ValueError(f"Cannot determine choice from {completion}")
Expand Down
31 changes: 17 additions & 14 deletions tests/test_dictator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,29 +80,30 @@ def test_dictator_choice(


@pytest.mark.parametrize(
"text, expected_result",
"condition, text, expected_result",
[
("Choice: Black", BLACK),
("Choice: 'project black'", BLACK),
("choice: 'Project BLACK'", BLACK),
("choice:Brown", BROWN),
("choice: Green", GREEN),
("Choice: Blue", BLUE),
("Choice: White", WHITE),
(lazy_fixture("base_condition"), "Choice: Black", BLACK),
(lazy_fixture("base_condition"), "Choice: 'project black'", BLACK),
(lazy_fixture("base_condition"), "choice: 'Project BLACK'", BLACK),
(lazy_fixture("base_condition"), "choice:Brown", BROWN),
(lazy_fixture("base_condition"), "choice: Green", GREEN),
(lazy_fixture("base_condition"), "Choice: Blue", BLUE),
(lazy_fixture("base_condition"), "Choice: White", WHITE),
(lazy_fixture("with_numerals"), "Choice: 5", WHITE),
(lazy_fixture("with_numbers"), "Choice: four", BLUE),
(lazy_fixture("with_numbers"), "Choice: Three", GREEN),
],
)
def test_extract_choice_dictator(
text: str, expected_result: DictatorChoice, base_condition: Participant
condition: Participant, text: str, expected_result: DictatorChoice
):
assert extract_choice_dictator(condition, user_message(text)) == expected_result
assert (
extract_choice_dictator(base_condition, user_message(text)) == expected_result
)
assert (
extract_choice_dictator(base_condition, user_message(text.upper()))
extract_choice_dictator(condition, user_message(text.upper()))
== expected_result
)
assert (
extract_choice_dictator(base_condition, user_message(text.lower()))
extract_choice_dictator(condition, user_message(text.lower()))
== expected_result
)

Expand Down Expand Up @@ -138,6 +139,8 @@ def test_compute_freq_dictator(test_choice: DictatorChoice):
lazy_fixture("with_gender_neutral_pronoun"),
lazy_fixture("with_upper_case"),
lazy_fixture("with_chain_of_thought"),
lazy_fixture("with_numerals"),
lazy_fixture("with_numbers"),
],
)
def test_get_prompt_dictator(condition: Participant):
Expand Down

0 comments on commit 17af3b3

Please sign in to comment.