
Commit

Modified the prompt inputs
kalyanjkk committed Sep 3, 2024
1 parent 1fdc504 commit a2a9b4a
Showing 1 changed file with 29 additions and 18 deletions.
47 changes: 29 additions & 18 deletions examples/offline_inference_granite.py
@@ -15,18 +15,18 @@
 from vllm.sampling_params import SamplingParams
 
 
-def generate_random_coding_question():
+def generate_random_coding_question(output_len):
     questions = [
         "How do you reverse a string in Python?",
-        "Write a Python function to check if a number is prime.",
-        "Explain the difference between a list and a tuple in Python.",
-        "Write a Python script to merge two dictionaries.",
-        "What is the use of the 'with' statement in Python?",
-        "Write a Python program to find the factorial of a number using recursion.",
-        "How do you handle exceptions in Python?",
-        "Write a Python class to implement a basic calculator.",
-        "Explain the concept of decorators in Python.",
-        "Write a Python function to sort a list of tuples based on the second element."
+        "list all the imports available in python"
+        # "Write a Python function to check if a number is prime.",
+        # "Explain the difference between a list and a tuple in Python.",
+        # "Write a Python script to merge two dictionaries.",
+        # "What is the use of the 'with' statement in Python?",
+        # "Write a Python program to find the factorial of a number using recursion.",
+        # "How do you handle exceptions in Python?",
+        # "Write a Python class to implement a basic calculator.",
+        # "Explain the concept of decorators in Python.",
+        # "Write a Python function to sort a list of tuples based on the second element."
     ]
     return random.choice(questions)

@@ -40,11 +40,25 @@ def str2bool(v):
     else:
         raise argparse.ArgumentTypeError('Boolean value expected.')
 
-def get_conversation():
+def get_conversation(num_words,output_len):
     assistant_adjectives = ['an enthusiastic', 'a knowledgeable', 'a curious', 'a patient', 'an insightful', 'a clever']
     assistants = ['coder', 'developer', 'programmer', 'software engineer', 'tech enthusiast', 'Python expert']
 
     sys_message = ConversationMessage(role='system', content=f'You are {random.choice(assistant_adjectives)} {random.choice(assistants)}. You enjoy sharing your knowledge about Python programming.')
-    user_message = ConversationMessage(role='user', content=generate_random_coding_question())
+    user_message_content = []
+    word_count = 0
+
+    while word_count < num_words:
+        question = generate_random_coding_question(num_words)
+        words_in_question = question.split()
+        if word_count + len(words_in_question) > num_words:
+            # Add only the number of words needed to reach the exact word count
+            user_message_content.extend(words_in_question[:num_words - word_count])
+        else:
+            user_message_content.extend(words_in_question)
+        word_count = len(user_message_content)
+
+    user_message = ConversationMessage(role='user', content=' '.join(user_message_content))
     return [sys_message, user_message]
 
 def main():
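The reworked get_conversation() pads the user message to exactly num_words whitespace-separated words by sampling questions and truncating the last one. A minimal standalone sketch of the same loop, with a stubbed question generator (build_user_content is an illustrative name, not part of the file):

import random

def generate_random_coding_question(output_len):
    # Stub mirroring the example file; output_len is accepted but unused, as in the commit.
    return random.choice([
        "How do you reverse a string in Python?",
        "list all the imports available in python",
    ])

def build_user_content(num_words):
    # Illustrative helper: the same word-capping loop as get_conversation().
    words, count = [], 0
    while count < num_words:
        q = generate_random_coding_question(num_words).split()
        if count + len(q) > num_words:
            # Keep only enough words to land exactly on num_words
            words.extend(q[:num_words - count])
        else:
            words.extend(q)
        count = len(words)
    return ' '.join(words)

print(len(build_user_content(40).split()))  # -> 40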
@@ -65,6 +79,7 @@ def main():
     parser.add_argument('--warmup', type=int, default=0, help='Number of warmup iterations to skip')
     parser.add_argument('--fp8', type=str2bool, nargs='?', const=True, default=False, help='Boolean flag to enable fp8')
     parser.add_argument('--measure', type=str2bool, nargs='?', const=True, default=False, help='Boolean flag to enable fp8 measurements')
+    parser.add_argument('--input-seq-len', type=int, default=256, help='Maximum input sequence length')
 
     args = parser.parse_args()
@@ -108,18 +123,14 @@ def main():
 
     chat_template = load_chat_template(provided_chat_template)
     tokenizer = llm.llm_engine.get_tokenizer()
-    conversations = [get_conversation() for _ in range(batch_size)]
+    conversations = [get_conversation(args.input_seq_len,args.max_tokens) for _ in range(batch_size)]
     prompts = [tokenizer.apply_chat_template(
         conversation=conversation,
         tokenize=False,
         add_generation_prompt=True,
         chat_template=chat_template,
     ) for conversation in conversations]
 
-    # Warmup iterations
-    for _ in range(warmup):
-        _ = llm.generate(prompts, sampling_params)
-
     # Measure performance for the final iteration
     start = time.time()
     profile_ctx = contextlib.nullcontext()
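Note that num_words caps words, not tokens, so after apply_chat_template() the prompt handed to the model may contain more or fewer than --input-seq-len tokens. One way to sanity-check the actual token count (a sketch; the checkpoint name is a placeholder, and any tokenizer that ships a chat template will do):

from transformers import AutoTokenizer

# Placeholder checkpoint, assumed only for illustration.
tok = AutoTokenizer.from_pretrained("ibm-granite/granite-3b-code-instruct")
conversation = [
    {"role": "system", "content": "You are a Python expert."},
    {"role": "user", "content": " ".join(["word"] * 256)},  # exactly 256 words
]
prompt = tok.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
print(len(tok(prompt).input_ids))  # typically != 256: words and tokens do not map 1:1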
