Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion skyrl/backends/skyrl_train/inference_servers/vllm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def load_weights(self, request: bytes) -> None:
"""
import pickle

from vllm.config import set_current_vllm_config

# Unpickle request to restore the original object type
assert isinstance(request, bytes), f"Expected bytes, got {type(request).__name__}"
request = pickle.loads(request)
Expand All @@ -90,7 +92,8 @@ def load_weights(self, request: bytes) -> None:
for name, tensor in self._weight_receiver.receive_weights(request):
weight_list.append((name, tensor))

self.model_runner.model.load_weights(weights=weight_list)
with torch.device(self.device), set_current_vllm_config(self.vllm_config):
self.model_runner.reload_weights(weights_iterator=iter(weight_list))
Comment thread
erictang000 marked this conversation as resolved.

for weight in weight_list:
del weight
Expand Down
Loading