add: adding precommit config configuration #104

Open
wants to merge 1 commit into base: main
13 changes: 13 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,13 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.4.0
+    hooks:
+      - id: check-yaml
+      - id: end-of-file-fixer
+        types: [python]
+      - id: trailing-whitespace
+
+  - repo: https://github.com/psf/black
+    rev: 23.1.0
+    hooks:
+      - id: black
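The config pins two hook repositories: generic checks from pre-commit-hooks v4.4.0 (check-yaml, end-of-file-fixer, trailing-whitespace) and black 23.1.0 for formatting. Note that `types: [python]` restricts end-of-file-fixer to Python files. A minimal sketch of trying the hooks locally, assuming pre-commit is installed (for example via `pip install pre-commit`); the Python wrapper below is only illustrative and simply shells out to the standard CLI:

# Sketch: wire the hooks into git, then run every configured hook once over the repo.
import subprocess

# Installs the .git/hooks/pre-commit shim so the checks run on every `git commit`.
subprocess.run(["pre-commit", "install"], check=True)

# One-off run of all configured hooks against every tracked file.
subprocess.run(["pre-commit", "run", "--all-files"], check=True)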
22 changes: 12 additions & 10 deletions api/codegeex-api-example-python/generation_example.py
@@ -4,29 +4,31 @@

 import requests
 
-'''
+"""
 Code Generation
-'''
+"""
 API_KEY = ""  # Get from Tianqi console. 从控制台获取
 API_SECRET = ""  # Get from Tianqi console. 从控制台获取
-PROMPT = "from typing import List\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n    " \
-         "\"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n    given " \
-         "threshold.\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n    False\n    >>> has_close_elements(" \
-         "[1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n    True\n    \"\"\"\n"
+PROMPT = (
+    "from typing import List\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n    "
+    '""" Check if in given list of numbers, are any two numbers closer to each other than\n    given '
+    "threshold.\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n    False\n    >>> has_close_elements("
+    '[1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n    True\n    """\n'
+)
 NUMBER = 3
 LANG = "Python"
 request_url = "https://tianqi.aminer.cn/api/v2/"
-api = 'multilingual_code_generate'
+api = "multilingual_code_generate"
 
 # Request is in json format. 指定请求参数格式为json
-headers = {'Content-Type': 'application/json'}
+headers = {"Content-Type": "application/json"}
 request_url = request_url + api
 data = {
     "apikey": API_KEY,
     "apisecret": API_SECRET,
     "prompt": PROMPT,
     "n": NUMBER,
-    "lang": LANG
+    "lang": LANG,
 }
 
 
@@ -36,5 +38,5 @@ def main():
     print(response.json())
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
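The PROMPT rewrite above is behavior-preserving: Python fuses adjacent string literals at compile time, so black's parenthesized layout builds exactly the same string as the old backslash continuations. A self-contained check (the variable names here are illustrative, not from the PR):

# Adjacent string literals are concatenated by the parser, so both spellings
# below produce the same value; only the source layout differs.
backslash_style = "from typing import List\n" \
                  "def f():\n"
paren_style = (
    "from typing import List\n"
    "def f():\n"
)
assert backslash_style == paren_style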
23 changes: 14 additions & 9 deletions codegeex/__init__.py
@@ -13,9 +13,9 @@ def get_model(


 def generate(
-    model, 
-    tokenizer: CodeGeeXTokenizer, 
-    prompt: str, 
+    model,
+    tokenizer: CodeGeeXTokenizer,
+    prompt: str,
     out_seq_length: int,
     seq_length: int = 2048,
     top_k: int = 0,
@@ -32,7 +32,7 @@ def generate(
     if verbose:
         print(f"Current prompt:\n{prompt}")
         print("N_token_prompt:", n_token_prompt)
-    
+
     generated_codes = []
     if backend == "megatron":
         token_stream = get_token_stream(
@@ -53,17 +53,22 @@ def generate(
            for j in range(micro_batch_size):
                if is_finished[j]:
                    continue
-                
-                if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(generated_tokens[j]) >= out_seq_length:
+
+                if (
+                    generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id
+                    or len(generated_tokens[j]) >= out_seq_length
+                ):
                     is_finished[j] = True
                     generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
-                    generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
+                    generated_code = tokenizer.decode_code(
+                        generated_tokens_[n_token_prompt:]
+                    )
                     generated_code = "".join(generated_code)
                     generated_codes.append(generated_code)
                     if verbose:
                         print(f"\nGenerated code {i}:\n{generated_code}")
-                
+
             if all(is_finished):
                 break
 
-    return generated_codes
\ No newline at end of file
+    return generated_codes
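Two details in this diff are invisible whitespace rather than logic changes: the removed and re-added parameter lines at the top differ only in trailing whitespace (stripped by the trailing-whitespace hook), and the paired `return generated_codes` lines at the end are most plausibly end-of-file-fixer adding the missing newline at end of file. The remaining edits are black's line-length wrapping of the same expressions.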
90 changes: 62 additions & 28 deletions codegeex/benchmark/evaluate_humaneval_x.py
@@ -16,10 +16,10 @@
 from codegeex.benchmark.execution import check_correctness
 
 LANGUAGE_NAME = {
-    "cpp"   : "CPP",
-    "go"    : "Go",
-    "java"  : "Java",
-    "js"    : "JavaScript",
+    "cpp": "CPP",
+    "go": "Go",
+    "java": "Java",
+    "js": "JavaScript",
     "python": "Python",
 }

@@ -29,7 +29,11 @@ def process_humaneval_test(sample, problems, example_test=False):
     language = task_id.split("/")[0].lower()
 
     prompt = sample["prompt"]
-    if example_test and "example_test" in problems[task_id] and problems[task_id]["example_test"] != "":
+    if (
+        example_test
+        and "example_test" in problems[task_id]
+        and problems[task_id]["example_test"] != ""
+    ):
         test = problems[task_id]["example_test"]
     else:
         test = problems[task_id]["test"]
@@ -39,7 +43,7 @@ def process_humaneval_test(sample, problems, example_test=False):
     if language == "python":
         code_ = []
         for line in code.split("\n"):
-            if (len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t'):
+            if len(line.strip()) > 0 and line[0] != " " and line[0] != "\t":
                 break
             code_.append(line)
         code = "\n".join(code_)
@@ -68,10 +72,21 @@ def process_humaneval_test(sample, problems, example_test=False):
             if pkg not in test_setup:
                 p = pkg.split("/")[-1]
                 if p + "." in code:
-                    other_pkgs.append(f"\"{pkg}\"")
+                    other_pkgs.append(f'"{pkg}"')
         if other_pkgs:
-            import_other_pkgs = "import (\n" + " ".join([p + "\n" for p in other_pkgs]) + ")"
-            test_string = test_setup + "\n" + import_other_pkgs + "\n" + prompt + code + "\n" + test
+            import_other_pkgs = (
+                "import (\n" + " ".join([p + "\n" for p in other_pkgs]) + ")"
+            )
+            test_string = (
+                test_setup
+                + "\n"
+                + import_other_pkgs
+                + "\n"
+                + prompt
+                + code
+                + "\n"
+                + test
+            )
         else:
             test_string = test_setup + "\n" + prompt + code + "\n" + test
     elif language == "rust":
@@ -97,21 +112,20 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]:


 def evaluate_functional_correctness(
-        input_file: str = None,
-        tmp_dir: str = "./",
-        n_workers: int = 32,
-        timeout: float = 500.0,
-        problem_file: str = "../data/humaneval_python.jsonl.gz",
-        out_dir: str = None,
-        k: List[int] = [1, 10, 100],
-        test_groundtruth: bool = False,
-        example_test: bool = False,
+    input_file: str = None,
+    tmp_dir: str = "./",
+    n_workers: int = 32,
+    timeout: float = 500.0,
+    problem_file: str = "../data/humaneval_python.jsonl.gz",
+    out_dir: str = None,
+    k: List[int] = [1, 10, 100],
+    test_groundtruth: bool = False,
+    example_test: bool = False,
 ):
     if example_test:
         print("Example test...")
 
-    problems = read_dataset(problem_file,
-                            dataset_type="humaneval")
+    problems = read_dataset(problem_file, dataset_type="humaneval")
     sample_jsonl = stream_jsonl_all(input_file)
 
     if example_test:
@@ -121,7 +135,9 @@ def evaluate_functional_correctness(
     if out_dir is not None:
         if not os.path.exists(out_dir):
             os.makedirs(out_dir)
-        out_file = os.path.join(out_dir, input_file.split('/')[-1].replace(".jsonl", suffix))
+        out_file = os.path.join(
+            out_dir, input_file.split("/")[-1].replace(".jsonl", suffix)
+        )
     else:
         out_file = os.path.join(input_file.replace(".jsonl", suffix))

@@ -149,10 +165,19 @@ def evaluate_functional_correctness(
lang = "js"
tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
sample["generation"] = sample["canonical_solution"]
sample["test_code"] = process_humaneval_test(sample, problems, example_test)
sample["test_code"] = process_humaneval_test(
sample, problems, example_test
)
if sample["test_code"] is None:
continue
args = (task_id, sample, lang, timeout, tmp_dir_, completion_id[task_id])
args = (
task_id,
sample,
lang,
timeout,
tmp_dir_,
completion_id[task_id],
)
future = executor.submit(check_correctness, *args)
futures.append(future)
completion_id[task_id] += 1
@@ -164,7 +189,11 @@ def evaluate_functional_correctness(
             lang = task_id.split("/")[0].lower()
             if translation_mode:
                 task_id = sample["task_id"].split("/")[-1]
-                lang = regex.findall("-to-.*-", input_file)[0].split("-to-")[-1].rstrip("-")
+                lang = (
+                    regex.findall("-to-.*-", input_file)[0]
+                    .split("-to-")[-1]
+                    .rstrip("-")
+                )
                 for l in LANGUAGE_NAME:
                     if l in lang:
                         lang = l
@@ -174,7 +203,9 @@ def evaluate_functional_correctness(
lang = "js"
tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
sample["task_id"] = task_id
sample["test_code"] = process_humaneval_test(sample, problems, example_test)
sample["test_code"] = process_humaneval_test(
sample, problems, example_test
)
if sample["test_code"] is None:
continue
if "completion_id" in sample:
@@ -208,8 +239,11 @@ def evaluate_functional_correctness(
     correct = np.array(correct)
     if evaluate_pass_at_k:
         ks = k
-        pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()
-                     for k in ks if (total >= k).all()}
+        pass_at_k = {
+            f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()
+            for k in ks
+            if (total >= k).all()
+        }
         print(pass_at_k)
     else:
         print("Total:", np.sum(total))
@@ -222,7 +256,7 @@ def evaluate_functional_correctness(
             for r in res:
                 fp.write((json.dumps(r[1]) + "\n").encode("utf-8"))
     else:
-        fp = open(out_file, 'w')
+        fp = open(out_file, "w")
         for res in results.values():
             for r in res:
                 fp.write(json.dumps(r[1]) + "\n")
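For context on the `estimate_pass_at_k` call reformatted above: pass@k is the standard unbiased estimator from the HumanEval/Codex paper, pass@k = E[1 - C(n - c, k) / C(n, k)] for n generated samples per problem of which c pass the tests. A minimal sketch of that estimator, assuming the repo's `estimate_pass_at_k` follows the reference implementation (the helper name below is local to this sketch):

import numpy as np

def pass_at_k_single(n: int, c: int, k: int) -> float:
    # Unbiased pass@k for one problem: 1 - C(n - c, k) / C(n, k),
    # computed as a running product for numerical stability.
    if n - c < k:
        return 1.0  # every size-k draw must contain at least one correct sample
    return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

# Example: 200 generations for a problem, 42 of them pass, estimate pass@10.
print(pass_at_k_single(200, 42, 10))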