diff --git a/server/models/huggingface_models.py b/server/models/huggingface_models.py
new file mode 100644
index 0000000..b442610
--- /dev/null
+++ b/server/models/huggingface_models.py
@@ -0,0 +1,89 @@
+from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
+from transformers import AutoTokenizer, T5ForConditionalGeneration
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import AutoModel
+from transformers import RobertaTokenizer, T5ForConditionalGeneration
+from transformers import AutoModelForCausalLM
+# from transformers import BitsAndBytesConfig
+import argparse
+import transformers
+import torch
+device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+# nf4_config = BitsAndBytesConfig(
+#     load_in_4bit=True,
+#     bnb_4bit_quant_type="nf4",
+#     bnb_4bit_use_double_quant=True,
+#     bnb_4bit_compute_dtype=torch.bfloat16
+# )
+
+
+def read_from_file(path: str) -> str:
+    opened_file = open(path, "r")
+    data = opened_file.read()
+    return data
+
+
+def codet5_base_model(text: str, max_len: int):
+    global device
+    tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base')
+    model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base').to(device)
+    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
+    generated_ids = model.generate(input_ids, max_length=max_len, do_sample=True, num_return_sequences=5).cpu()
+    print(generated_ids)
+    return map(lambda prompt_ans: tokenizer.decode(prompt_ans, skip_special_tokens=True), generated_ids)
+
+
+def starcoder_model(text: str, max_len: int):
+    global device
+    checkpoint = "bigcode/starcoderbase-1b"
+    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
+    inputs = tokenizer.encode(text, return_tensors="pt").to(device)
+    outputs = model.generate(inputs, max_length=max_len, do_sample=True, num_return_sequences=5).cpu()
+    print(outputs)
+    print("sent")
+    return map(lambda prompt_ans: tokenizer.decode(prompt_ans, skip_special_tokens=True), outputs)
+
+
+def healing(tokenizer, model, prefix, outputs):
+    pass
+
+
+def llama_model(text: str):
+    model_name = "codellama/CodeLlama-7b-hf"
+    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+    # model = LlamaForCausalLM.from_pretrained(model_name, quantization_config=nf4_config)
+    model = LlamaForCausalLM.from_pretrained(model_name)
+
+    pipeline = transformers.pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        torch_dtype=torch.float32,
+        device_map="auto",
+    )
+
+    sequences = pipeline(
+        text,
+        do_sample=True,
+        top_k=10,
+        temperature=0.1,
+        top_p=0.95,
+        num_return_sequences=1,
+        eos_token_id=tokenizer.eos_token_id,
+        max_length=200,
+    )
+    for seq in sequences:
+        print(f"Result: {seq['generated_text']}")
+
+
+if __name__ == "__main__":
+    # codet5_model("123")
+    # llama_model("def print_hello_world():")
+    # codet5_base_model("def print_hello_world():")
+    # codet5_small_model("def print_hello_world():")
+    # starcoder_model("def print_hello_world():")
+    # Intended to be run with a command like: python huggingface_models.py --model codellama/CodeLlama-7b-hf
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="codellama/CodeLlama-7b-hf")
+    args = parser.parse_args()
diff --git a/server/server.py b/server/server.py
index 84e7619..2ae9480 100644
--- a/server/server.py
+++ b/server/server.py
@@ -1,9 +1,13 @@
 from http.server import HTTPServer, BaseHTTPRequestHandler
 import argparse
 import json
-from models.huggingface_models import codet5_small_model
+from models.huggingface_models import codet5_base_model, starcoder_model
 
-models = [{"model_name": "StarCoder"}, {"model_name": "codellama"}, {"model_name": "codet5"}]
+models = [{"model_name": "StarCoder"}, {"model_name": "codeT5-base"}]
+maper = {
+    "StarCoder": starcoder_model,
+    "codeT5-base": codet5_base_model
+}
 
 
 class RequestHandler(BaseHTTPRequestHandler):
@@ -18,10 +22,13 @@ def do_POST(self):
         post_body = self.rfile.read(content_len)
         json_data = json.loads(post_body)
         text_received = json_data["prompt"]
-        processed_texts = codet5_small_model(text_received, json_data["max_new_tokens"])
+        model = maper[json_data["model"]]
+        processed_texts = model(text_received, json_data["max_new_tokens"])
+        start_index = len(text_received) if json_data["model"] == "StarCoder" else 0
         json_bytes = json.dumps(
-            {"results" : [{"text": text_received}, {"text": processed_texts}]}
-        ).encode("utf-8")
+            {"results" : list(map(lambda x: {"text": x[start_index:]}, processed_texts))}
+        ).encode("utf-8")
+        print(json_bytes)
         self.wfile.write(json_bytes)
 
     def do_GET(self):
@@ -33,7 +40,6 @@ def do_GET(self):
 def run(port, addr):
     server_address = (addr, port)
     httpd = HTTPServer(server_address, RequestHandler)
-    print(f"Starting httpd server on {addr}:{port}")
     httpd.serve_forever()
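
For reference, a minimal client sketch for the endpoint this patch changes: do_POST now expects a JSON body with "prompt", "max_new_tokens", and a "model" key matching one of the names in the models list, and replies with {"results": [{"text": ...}, ...]}. The host and port below are assumptions; the real values are supplied to run() via argparse in server.py and are outside this diff.

import json
import urllib.request

# Hypothetical address; adjust to the server's actual --addr/--port settings.
SERVER_URL = "http://localhost:8000/"

payload = {
    "prompt": "def print_hello_world():",
    "max_new_tokens": 64,
    "model": "StarCoder",  # or "codeT5-base", matching the keys of maper in server.py
}
request = urllib.request.Request(
    SERVER_URL,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(request) as response:
    results = json.loads(response.read())["results"]
for completion in results:
    print(completion["text"])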