Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions profile_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAGraph.h>
#include "pufferlib/extensions/cuda/kernels.cu"
#include "pufferlib/extensions/modules.cu"

// models.cpp provides Policy, MinGRU, DefaultEncoder, DefaultDecoder,
Expand Down Expand Up @@ -1003,6 +1002,7 @@ typedef struct {
float* adv_mean;
float* adv_var; // variance, kernel does sqrt
float* loss;
float* losses_acc; // (LOSS_N,) accumulator for loss components
double* saved_for_backward;
precision_t* ratio_out;
precision_t* newvalue_out;
Expand Down Expand Up @@ -1046,6 +1046,8 @@ PPOLossArgs* create_ppolossargs(int batch, int seq, int actions) {
cudaMalloc(&args->adv_mean, sizeof(float));
cudaMalloc(&args->adv_var, sizeof(float));
cudaMalloc(&args->loss, sizeof(float));
cudaMalloc(&args->losses_acc, LOSS_N * sizeof(float));
cudaMemset(args->losses_acc, 0, LOSS_N * sizeof(float));
cudaMalloc(&args->saved_for_backward, NT * 5 * sizeof(double));
cudaMalloc(&args->ratio_out, NT * sizeof(precision_t));
cudaMalloc(&args->newvalue_out, NT * sizeof(precision_t));
Expand Down Expand Up @@ -1132,6 +1134,7 @@ void free_ppolossargs(PPOLossArgs* args) {
cudaFree(args->adv_mean);
cudaFree(args->adv_var);
cudaFree(args->loss);
cudaFree(args->losses_acc);
cudaFree(args->saved_for_backward);
cudaFree(args->ratio_out);
cudaFree(args->newvalue_out);
Expand All @@ -1147,7 +1150,8 @@ void run_ppoloss_forward(PPOLossArgs* args) {
int ppo_grid = (total + PPO_THREADS - 1) / PPO_THREADS;
cudaMemset(args->loss, 0, sizeof(float));
ppo_loss_forward_kernel_optimized<<<ppo_grid, PPO_THREADS>>>(
args->loss, args->saved_for_backward,
args->loss, args->losses_acc,
args->saved_for_backward,
args->ratio_out, args->newvalue_out,
args->logits,
nullptr, // logstd (nullptr for discrete)
Expand Down Expand Up @@ -1198,6 +1202,7 @@ typedef struct {
torch::Tensor ratio_out;
torch::Tensor newvalue_out;
torch::Tensor act_sizes;
torch::Tensor losses_acc;
torch::Tensor loss;
float clip_coef;
float vf_clip_coef;
Expand Down Expand Up @@ -1233,6 +1238,7 @@ PPOLossArgsTorch* create_ppolossargs_torch(PPOLossArgs* raw) {
args->ratio_out = torch::zeros({raw->N, raw->T}, opts);
args->newvalue_out = torch::zeros({raw->N, raw->T}, opts);
args->act_sizes = torch::tensor({raw->A}, cuda_i32);
args->losses_acc = torch::zeros({NUM_LOSSES}, cuda_f32);

return args;
}
Expand All @@ -1247,6 +1253,7 @@ torch::Tensor run_fused_ppo_forward(PPOLossArgsTorch* args) {
args->old_logprobs, args->advantages, args->prio,
args->values, args->returns, args->adv_mean, args->adv_var,
args->ratio_out, args->newvalue_out, args->act_sizes,
args->losses_acc,
args->clip_coef, args->vf_clip_coef, args->vf_coef, args->ent_coef)[0];
}

Expand All @@ -1271,7 +1278,7 @@ TrainLossResult compute_test_loss(PPOLossArgsTorch* args, bool use_kernels) {
args->act_sizes, torch::tensor({(int64_t)A}, torch::dtype(torch::kInt64)),
N * T, T,
args->clip_coef, args->vf_clip_coef, args->vf_coef, args->ent_coef,
/*is_continuous=*/false, use_kernels);
/*is_continuous=*/false, use_kernels, args->losses_acc);
return {loss, logits, values_pred};
}

Expand Down Expand Up @@ -2307,6 +2314,7 @@ typedef struct {
Tensor newvalue_out; // (N, T) - side-effect output
Tensor act_sizes; // (1,) int32 cuda
Tensor act_sizes_cpu; // (1,) int64 cpu
Tensor losses; // (NUM_LOSSES,) float32 accumulator

// Muon optimizer (matches production — NOT Adam)
std::shared_ptr<torch::optim::Muon> muon;
Expand Down Expand Up @@ -2382,6 +2390,7 @@ TrainArgs* create_trainargs(int N, int T_seq, int input_size, int hidden, int ac
args->newvalue_out = torch::zeros({N, T_seq}, opts);
args->act_sizes = torch::tensor({act_n}, cuda_i32);
args->act_sizes_cpu = torch::tensor({(int64_t)act_n}, torch::dtype(torch::kInt64));
args->losses = torch::zeros({NUM_LOSSES}, cuda_f32);

return args;
}
Expand All @@ -2402,7 +2411,7 @@ Tensor compute_loss_impl(TrainArgs* args, Logits& raw_logits, Tensor& newvalue)
args->values, args->returns, args->ratio_out, args->newvalue_out,
args->act_sizes, args->act_sizes_cpu, mb, args->T_seq,
args->clip_coef, args->vf_clip_coef, args->vf_coef, args->ent_coef,
/*is_continuous=*/false, args->use_kernels);
/*is_continuous=*/false, args->use_kernels, args->losses);
}

// Run functions for individual phases
Expand Down
4 changes: 2 additions & 2 deletions pufferlib/extensions/cuda/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,7 @@ __global__ void select_copy_kernel(
const precision_t* __restrict__ src_values, precision_t* __restrict__ dst_values,
const float* __restrict__ src_advantages, float* __restrict__ dst_advantages,
precision_t* __restrict__ dst_returns, int horizon,
const precision_t* __restrict__ src_prio, precision_t* __restrict__ dst_prio
const float* __restrict__ src_prio, precision_t* __restrict__ dst_prio
) {
int mb = blockIdx.x;
int ch = blockIdx.y;
Expand All @@ -1120,7 +1120,7 @@ __global__ void select_copy_kernel(
break;
case 4:
if (threadIdx.x == 0) {
dst_prio[mb] = src_prio[mb];
dst_prio[mb] = from_float(src_prio[mb]);
break;
}
}
Expand Down
2 changes: 1 addition & 1 deletion pufferlib/extensions/modules.cu
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ void train_select_and_copy_cuda(
(const precision_t*)values.data_ptr(), (precision_t*)dst_values.data_ptr(),
advantages.data_ptr<float>(), dst_advantages.data_ptr<float>(),
(precision_t*)dst_returns.data_ptr(), horizon,
(const precision_t*)mb_prio.data_ptr(), (precision_t*)dst_prio.data_ptr());
mb_prio.data_ptr<float>(), (precision_t*)dst_prio.data_ptr());
}

// Host dispatch: replaces ~9 PyTorch kernel launches with 3 custom + multinomial
Expand Down
5 changes: 2 additions & 3 deletions pufferlib/extensions/pufferlib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ void train_impl(PuffeRL& pufferl) {
auto out = compute_prio_cuda(advantages, prio_alpha, minibatch_segments,
hypers.total_agents, anneal_beta);
idx = std::get<0>(out);
mb_prio = std::get<1>(out);
mb_prio = std::get<1>(out); // always fp32 (comes from advantages)
}
else {
Tensor adv = advantages.abs().sum(1);
Expand All @@ -436,8 +436,7 @@ void train_impl(PuffeRL& pufferl) {
profile_end(hypers.profile);

profile_begin("train_select_and_copy", hypers.profile);
// Broken kernel
if (false && hypers.kernels) {
if (hypers.kernels) {
train_select_and_copy_cuda(
rollouts.observations, rollouts.actions, rollouts.logprobs,
rollouts.values, advantages,
Expand Down
Loading