diff --git a/lib/cuda/total_variation_kernel.cu b/lib/cuda/total_variation_kernel.cu index 72fab09..7a49795 100644 --- a/lib/cuda/total_variation_kernel.cu +++ b/lib/cuda/total_variation_kernel.cu @@ -28,8 +28,8 @@ __global__ void total_variation_add_grad_cuda_kernel( grad_to_add += (k==sz_k-1 ? 0 : wz * clamp(param[index]-param[index+1], -1.f, 1.f)); grad_to_add += (j==0 ? 0 : wy * clamp(param[index]-param[index-sz_k], -1.f, 1.f)); grad_to_add += (j==sz_j-1 ? 0 : wy * clamp(param[index]-param[index+sz_k], -1.f, 1.f)); - grad_to_add += (i==0 ? 0 : wz * clamp(param[index]-param[index-sz_k*sz_j], -1.f, 1.f)); - grad_to_add += (i==sz_i-1 ? 0 : wz * clamp(param[index]-param[index+sz_k*sz_j], -1.f, 1.f)); + grad_to_add += (i==0 ? 0 : wx * clamp(param[index]-param[index-sz_k*sz_j], -1.f, 1.f)); + grad_to_add += (i==sz_i-1 ? 0 : wx * clamp(param[index]-param[index+sz_k*sz_j], -1.f, 1.f)); grad[index] += grad_to_add; } }