RayTaskError #77

Closed
Raf-Chen opened this issue Jan 4, 2025 · 2 comments
Labels: bug (Something isn't working), vllm related

Comments

Raf-Chen commented Jan 4, 2025

Thanks for the lib! When I run the script run_qwen2-7b_rm.sh, a RayTaskError occurs. I have checked that the Ray version is correct (2.10), but this error is not sporadic—it happens every time. Could you please tell me what might be causing this issue?

Error executing job with overrides: ["data.train_files=['/cpfs01/user/zhangyuchen/data/gsm8k/train.parquet', '/cpfs01/user/zhangyuchen/data/math/train.parquet']", "data.val_files=['/cpfs01/user/zhangyuchen/data/gsm8k/test.parquet', '/cpfs01/user/zhangyuchen/data/math/test.parquet']", 'data.train_batch_size=256', 'data.val_batch_size=256', 'data.max_prompt_length=1024', 'data.max_response_length=1024', 'data.return_raw_chat=True', 'actor_rollout_ref.model.path=/cpfs01/user/zhangyuchen/hf-llms/Qwen/Qwen2.5-7B-Instruct', 'actor_rollout_ref.actor.optim.lr=1e-6', 'actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1', 'actor_rollout_ref.actor.ppo_mini_batch_size=256', 'actor_rollout_ref.actor.ppo_micro_batch_size=8', 'actor_rollout_ref.actor.fsdp_config.param_offload=False', 'actor_rollout_ref.actor.fsdp_config.grad_offload=False', 'actor_rollout_ref.actor.fsdp_config.optimizer_offload=False', 'actor_rollout_ref.rollout.log_prob_micro_batch_size=16', 'actor_rollout_ref.rollout.tensor_model_parallel_size=1', 'actor_rollout_ref.rollout.name=vllm', 'actor_rollout_ref.rollout.gpu_memory_utilization=0.6', 'actor_rollout_ref.ref.log_prob_micro_batch_size=16', 'actor_rollout_ref.ref.fsdp_config.param_offload=True', 'critic.optim.lr=1e-5', 'critic.optim.lr_warmup_steps_ratio=0.05', 'critic.model.path=/cpfs01/user/zhangyuchen/hf-llms/Qwen/Qwen2.5-7B-Instruct', 'critic.model.enable_gradient_checkpointing=False', 'critic.ppo_micro_batch_size=8', 'critic.model.fsdp_config.param_offload=False', 'critic.model.fsdp_config.grad_offload=False', 'critic.model.fsdp_config.optimizer_offload=False', 'reward_model.enable=True', 'reward_model.model.path=/cpfs01/user/zhangyuchen/hf-llms/Fsfair/FsfairX-Gemma2-RM-v0.1', 'reward_model.model.fsdp_config.param_offload=True', 'reward_model.micro_batch_size=8', 'algorithm.kl_ctrl.kl_coef=0.001', 'trainer.critic_warmup=0', 'trainer.logger=[console,wandb]', 'trainer.project_name=verl_example', 'trainer.experiment_name=Qwen2-7B-ppo_orm', 'trainer.n_gpus_per_node=8', 'trainer.nnodes=1', 'trainer.save_freq=-1', 'trainer.test_freq=5', 'trainer.total_epochs=15']
Traceback (most recent call last):
  File "/cpfs01/user/zhangyuchen/verl/verl/trainer/main_ppo.py", line 101, in main
    ray.get(main_task.remote(config))
  File "/usr/local/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/ray/_private/worker.py", line 2667, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/lib/python3.10/site-packages/ray/_private/worker.py", line 864, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::main_task() (pid=4579, ip=10.1.10.42)
  File "/cpfs01/user/zhangyuchen/verl/verl/trainer/main_ppo.py", line 187, in main_task
    trainer.fit()
  File "/cpfs01/user/zhangyuchen/verl/verl/trainer/ppo/ray_trainer.py", line 550, in fit
    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
  File "/cpfs01/user/zhangyuchen/verl/verl/single_controller/ray/base.py", line 42, in func
    output = ray.get(output)
ray.exceptions.RayTaskError(RuntimeError): ray::WorkerDict.actor_rollout_generate_sequences() (pid=6677, ip=10.1.10.42, actor_id=ac4e0f92172811787d45948a01000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f73e4268ca0>)
  File "/usr/local/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1708, in execute_model
    output: SamplerOutput = self.model.sample(
  File "/usr/local/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 433, in sample
    next_tokens = self.sampler(logits, sampling_metadata)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/vllm/model_executor/layers/sampler.py", line 231, in forward
    self._init_sampling_tensors(logits, sampling_metadata)
  File "/usr/local/lib/python3.10/site-packages/vllm/model_executor/layers/sampler.py", line 195, in _init_sampling_tensors
    do_min_p) = SamplingTensors.from_sampling_metadata(
  File "/usr/local/lib/python3.10/site-packages/vllm/model_executor/sampling_metadata.py", line 471, in from_sampling_metadata
    sampling_tensors = SamplingTensors.from_lists(
  File "/usr/local/lib/python3.10/site-packages/vllm/model_executor/sampling_metadata.py", line 529, in from_lists
    temperatures_t = torch.tensor(
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


The above exception was the direct cause of the following exception:

ray::WorkerDict.actor_rollout_generate_sequences() (pid=6677, ip=10.1.10.42, actor_id=ac4e0f92172811787d45948a01000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f73e4268ca0>)
  File "/cpfs01/user/zhangyuchen/verl/verl/workers/fsdp_workers.py", line 363, in generate_sequences
    output = self.rollout.generate_sequences(prompts=prompts)
  File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/cpfs01/user/zhangyuchen/verl/verl/workers/rollout/vllm_rollout/vllm_rollout.py", line 174, in generate_sequences
    output = self.inference_engine.generate(
  File "/usr/local/lib/python3.10/site-packages/vllm/utils.py", line 1063, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 353, in generate
    outputs = self._run_engine(use_tqdm=use_tqdm)
  File "/cpfs01/user/zhangyuchen/verl/verl/third_party/vllm/vllm_v_0_6_3/llm.py", line 161, in _run_engine
    outputs = super()._run_engine(use_tqdm=use_tqdm)
  File "/usr/local/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 879, in _run_engine
    step_outputs = self.llm_engine.step()
  File "/usr/local/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 1386, in step
    outputs = self.model_executor.execute_model(
  File "/cpfs01/user/zhangyuchen/verl/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py", line 163, in execute_model
    all_outputs = self.worker.execute_model(execute_model_req=execute_model_req)
  File "/cpfs01/user/zhangyuchen/verl/verl/third_party/vllm/vllm_v_0_6_3/worker.py", line 267, in execute_model
    return self.model_runner.execute_model(
  File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/vllm/worker/model_runner_base.py", line 146, in _wrapper
    raise type(err)(f"Error in model execution: "
RuntimeError: Error in model execution: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.



"""
During handling of the above exception, another exception occurred:

ray::WorkerDict.actor_rollout_generate_sequences() (pid=6677, ip=10.1.10.42, actor_id=ac4e0f92172811787d45948a01000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f73e4268ca0>)
  File "/cpfs01/user/zhangyuchen/verl/verl/single_controller/ray/base.py", line 399, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/cpfs01/user/zhangyuchen/verl/verl/single_controller/base/decorator.py", line 404, in inner
    return func(*args, **kwargs)
  File "/cpfs01/user/zhangyuchen/verl/verl/workers/fsdp_workers.py", line 359, in generate_sequences
    with self.sharding_manager:
  File "/cpfs01/user/zhangyuchen/verl/verl/workers/hybrid_engine/fsdp_vllm.py", line 82, in __exit__
    torch.cuda.synchronize() 
  File "/usr/local/lib/python3.10/site-packages/torch/cuda/__init__.py", line 892, in synchronize
    return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
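
For reference, a minimal sketch of re-running with the debug flags the error output suggests, assuming the example script is launched directly with bash (the flags and script name are taken from the report above):

```bash
# CUDA_LAUNCH_BLOCKING=1 forces synchronous kernel launches so the illegal
# memory access is reported at the call that actually triggered it (slower).
# HYDRA_FULL_ERROR=1 prints the complete Hydra stack trace.
CUDA_LAUNCH_BLOCKING=1 HYDRA_FULL_ERROR=1 bash run_qwen2-7b_rm.sh
```
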
PeterSH6 (Collaborator) commented Jan 4, 2025

Hi @Raf-Chen, thanks for your report!

This bug may be related to the FlashAttention backend in vLLM. Can you try a different vLLM attention backend by setting export VLLM_ATTENTION_BACKEND=XFORMERS?
Similar issue: #12

This is quite a common issue in vLLM. I'll check whether vLLM has other workarounds for FlashAttention.
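
For reference, a minimal sketch of applying the suggested workaround, assuming the launch script from the report above; vLLM reads VLLM_ATTENTION_BACKEND when the engine starts:

```bash
# Switch vLLM's attention backend from FlashAttention to xFormers for rollout generation.
export VLLM_ATTENTION_BACKEND=XFORMERS
bash run_qwen2-7b_rm.sh
```

Putting the export line at the top of the launch script should have the same effect on a single-node run.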

PeterSH6 added the bug and vllm related labels on Jan 4, 2025

Raf-Chen (Author) commented Jan 7, 2025

Thanks for the response!

Raf-Chen closed this as completed on Jan 7, 2025