Skip to content

Commit

Permalink
Formatting + isort
Browse files Browse the repository at this point in the history
  • Loading branch information
vinhowe committed May 23, 2023
1 parent bebcdc5 commit 90c2181
Show file tree
Hide file tree
Showing 15 changed files with 25 additions and 20 deletions.
2 changes: 1 addition & 1 deletion example_configure_survey.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from lm_survey.survey.survey import Survey
import os

from lm_survey.survey.survey import Survey

if __name__ == "__main__":
survey_directory = os.path.join("data", "ATP", "American_Trends_Panel_W92")
Expand Down
1 change: 0 additions & 1 deletion lm_survey/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import typing


MULTIPLE_CHOICE_LIST: typing.List[str] = [
"A",
"B",
Expand Down
1 change: 0 additions & 1 deletion lm_survey/prompt_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from lm_survey.constants import MULTIPLE_CHOICE_LIST


INDEPENDENT_VARIABLE_SUMMARY_TEMPLATE = """Context: {context_summary}
{dependent_variable_prompt}"""
Expand Down
4 changes: 2 additions & 2 deletions lm_survey/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from lm_survey.samplers.base_sampler import BaseSampler
from lm_survey.samplers.async_openai_sampler import AsyncOpenAiSampler
from lm_survey.samplers.auto_sampler import AutoSampler
from lm_survey.samplers.base_sampler import BaseSampler
from lm_survey.samplers.hf_sampler import HfSampler
from lm_survey.samplers.openai_sampler import OpenAiSampler
from lm_survey.samplers.async_openai_sampler import AsyncOpenAiSampler
3 changes: 1 addition & 2 deletions lm_survey/samplers/async_openai_sampler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import sys

import openai
from openai.error import RateLimitError

import torch
from aiolimiter import AsyncLimiter
from openai.error import RateLimitError

from lm_survey.samplers.base_sampler import BaseSampler, MaybeAwaitable

Expand Down
4 changes: 2 additions & 2 deletions lm_survey/samplers/auto_sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from lm_survey.samplers.hf_sampler import HfSampler
from lm_survey.samplers.openai_sampler import OpenAiSampler
from lm_survey.samplers.async_openai_sampler import AsyncOpenAiSampler
from lm_survey.samplers.base_sampler import BaseSampler
from lm_survey.samplers.hf_sampler import HfSampler
from lm_survey.samplers.openai_sampler import OpenAiSampler


class AutoSampler(BaseSampler):
Expand Down
2 changes: 1 addition & 1 deletion lm_survey/samplers/base_sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABCMeta, abstractmethod
import typing
from abc import ABCMeta, abstractmethod

T = typing.TypeVar("T")
MaybeAwaitable = typing.Union[T, typing.Awaitable[T]]
Expand Down
5 changes: 3 additions & 2 deletions lm_survey/samplers/hf_sampler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import typing
from lm_survey.samplers.base_sampler import BaseSampler

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_survey.samplers.base_sampler import BaseSampler


class HfSampler(BaseSampler):
Expand Down
2 changes: 1 addition & 1 deletion lm_survey/samplers/openai_sampler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import typing

import openai
import tiktoken
import torch

from lm_survey.samplers.base_sampler import BaseSampler
import openai

OPENAI_TOKEN_COSTS = {
# cents per 1000 tokens
Expand Down
6 changes: 3 additions & 3 deletions lm_survey/survey/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from lm_survey.survey.dependent_variable_sample import (
DependentVariableSample,
Completion,
DependentVariableSample,
)
from lm_survey.survey.survey import Survey
from lm_survey.survey.variable import Variable
from lm_survey.survey.question import Question, ValidOption
from lm_survey.survey.results import SurveyResults
from lm_survey.survey.survey import Survey
from lm_survey.survey.variable import Variable
1 change: 1 addition & 0 deletions lm_survey/survey/dependent_variable_sample.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typing

from lm_survey.survey.question import Question


Expand Down
2 changes: 2 additions & 0 deletions lm_survey/survey/question.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import typing

import pandas as pd

from lm_survey.constants import MULTIPLE_CHOICE_LIST
from lm_survey.prompt_templates import (
DEPENDENT_VARIABLE_TEMPLATE,
Expand Down
2 changes: 2 additions & 0 deletions lm_survey/survey/results.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
import os
import typing

import pandas as pd
import pandas.core.groupby.generic

from lm_survey.survey.dependent_variable_sample import DependentVariableSample


Expand Down
8 changes: 4 additions & 4 deletions lm_survey/survey/survey.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import argparse
import functools
import json
import os
import typing
from pathlib import Path
Expand All @@ -6,16 +9,13 @@
import pandas as pd
from sklearn.metrics import normalized_mutual_info_score

from lm_survey.prompt_templates import INDEPENDENT_VARIABLE_SUMMARY_TEMPLATE
from lm_survey.survey.dependent_variable_sample import (
Completion,
DependentVariableSample,
)
from lm_survey.survey.question import Question, ValidOption
from lm_survey.survey.variable import Variable
from lm_survey.prompt_templates import INDEPENDENT_VARIABLE_SUMMARY_TEMPLATE
import json
import functools
import argparse


class Survey:
Expand Down
2 changes: 2 additions & 0 deletions lm_survey/survey/variable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import typing

import pandas as pd

from lm_survey.survey.question import Question


Expand Down

0 comments on commit 90c2181

Please sign in to comment.