ci: Fix L0_model_control_stress_vllm (#79)
yinggeh authored Dec 23, 2024
Parent: 2f5bfbd · Commit: d061556
Showing 1 changed file with 15 additions and 6 deletions.
samples/client.py: 21 changes (15 additions, 6 deletions)
@@ -38,13 +38,21 @@
 
 class LLMClient:
     def __init__(self, flags: argparse.Namespace):
-        self._client = grpcclient.InferenceServerClient(
-            url=flags.url, verbose=flags.verbose
-        )
         self._flags = flags
-        self._loop = asyncio.get_event_loop()
         self._results_dict = {}
 
+    def get_triton_client(self):
+        try:
+            triton_client = grpcclient.InferenceServerClient(
+                url=self._flags.url,
+                verbose=self._flags.verbose,
+            )
+        except Exception as e:
+            print("channel creation failed: " + str(e))
+            sys.exit()
+
+        return triton_client
+
     async def async_request_iterator(
         self, prompts, sampling_parameters, exclude_input_in_output
     ):
@@ -65,8 +73,9 @@ async def async_request_iterator
 
     async def stream_infer(self, prompts, sampling_parameters, exclude_input_in_output):
        try:
+            triton_client = self.get_triton_client()
             # Start streaming
-            response_iterator = self._client.stream_infer(
+            response_iterator = triton_client.stream_infer(
                 inputs_iterator=self.async_request_iterator(
                     prompts, sampling_parameters, exclude_input_in_output
                 ),
@@ -138,7 +147,7 @@ async def run(self):
             print("FAIL: vLLM example")
 
     def run_async(self):
-        self._loop.run_until_complete(self.run())
+        asyncio.run(self.run())
 
     def create_request(
         self,

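Net effect of the patch: the gRPC channel is no longer created once in __init__ and cached for the client's lifetime; stream_infer() now obtains a fresh InferenceServerClient via get_triton_client() on each call, and run_async() delegates event-loop management to asyncio.run() instead of a loop captured at construction time. Below is a minimal driver sketch for the patched sample, assuming only the names visible in the diff (LLMClient, run_async, and the url/verbose flags); the real samples/client.py defines additional command-line arguments that are omitted here.

    import argparse

    # Hypothetical minimal flag set; samples/client.py accepts more options.
    # Assumes LLMClient is in scope (e.g. this runs in the sample's __main__
    # block or LLMClient is imported from samples/client.py).
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="localhost:8001",
                        help="Triton gRPC endpoint")
    parser.add_argument("--verbose", action="store_true",
                        help="Enable verbose client logging")
    flags = parser.parse_args()

    client = LLMClient(flags)  # no channel is opened here anymore
    client.run_async()         # wraps asyncio.run(self.run()) after this commit

Opening the channel per stream_infer() call trades a small amount of connection overhead for robustness: a channel that went stale between iterations can no longer poison later requests, which is presumably the failure mode behind this L0_model_control_stress_vllm fix.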