[Feature][Hardware][Amd] Add fp8 Linear Layer for Rocm (#7210)
This commit is contained in:
@ -9,6 +9,18 @@
|
||||
|
||||
#include "../../reduction_utils.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
using FP8_TYPE = c10::Float8_e4m3fn;
|
||||
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
|
||||
std::numeric_limits<FP8_TYPE>::max();
|
||||
#else
|
||||
#include "amd/hip_float8.h"
|
||||
using FP8_TYPE = c10::Float8_e4m3fnuz;
|
||||
// Using the default max value from pytorch (240.0) will cause accuracy
|
||||
// issue when running dynamic quantization. Here use 224.0f for rocm.
|
||||
constexpr auto FP8_E4M3_MAX = 224.0f;
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
|
||||
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
@ -21,11 +33,9 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
return old;
|
||||
}
|
||||
|
||||
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
|
||||
|
||||
template <bool is_scale_inverted>
|
||||
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
|
||||
float const val, float const scale) {
|
||||
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
|
||||
float const scale) {
|
||||
float x = 0.0f;
|
||||
if constexpr (is_scale_inverted) {
|
||||
x = val * scale;
|
||||
@ -34,7 +44,13 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
|
||||
}
|
||||
|
||||
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
|
||||
#ifndef USE_ROCM
|
||||
return static_cast<c10::Float8_e4m3fn>(r);
|
||||
#else
|
||||
// Use hardware cvt instruction for fp8 on rocm
|
||||
return c10::Float8_e4m3fnuz(hip_fp8(r).data,
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
#endif
|
||||
}
|
||||
|
||||
// Compute the absolute maximum m of the input tensor and store
|
||||
@ -74,8 +90,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
|
||||
// Finally, since cache[0] contains the maximum for this thread block,
|
||||
// atomically write the max to the target location
|
||||
if (threadIdx.x == 0) {
|
||||
atomicMaxFloat(scale,
|
||||
cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
|
||||
atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
|
||||
}
|
||||
}
|
||||
|
||||
@ -88,10 +103,10 @@ struct __align__(8) vec4_t {
|
||||
};
|
||||
|
||||
typedef struct __align__(4) {
|
||||
c10::Float8_e4m3fn x;
|
||||
c10::Float8_e4m3fn y;
|
||||
c10::Float8_e4m3fn z;
|
||||
c10::Float8_e4m3fn w;
|
||||
FP8_TYPE x;
|
||||
FP8_TYPE y;
|
||||
FP8_TYPE z;
|
||||
FP8_TYPE w;
|
||||
}
|
||||
float8x4_t;
|
||||
|
||||
@ -124,7 +139,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool is_scale_inverted>
|
||||
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
|
||||
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
|
||||
scalar_t const* __restrict__ input,
|
||||
float const scale,
|
||||
int64_t const num_elems,
|
||||
@ -160,7 +175,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
|
||||
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale,
|
||||
int64_t num_elems) {
|
||||
@ -175,7 +190,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
|
||||
FP8_TYPE* __restrict__ out, float* __restrict__ scale,
|
||||
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
|
||||
const int hidden_size) {
|
||||
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
|
||||
@ -184,7 +199,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
int const token_idx = blockIdx.x;
|
||||
|
||||
scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
|
||||
c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];
|
||||
FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size];
|
||||
|
||||
// For vectorization, token_input and token_output pointers need to be
|
||||
// aligned at 8-byte and 4-byte addresses respectively.
|
||||
@ -241,7 +256,7 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
|
||||
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), num_elems);
|
||||
});
|
||||
}
|
||||
@ -261,7 +276,7 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
|
||||
scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
|
||||
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), num_elems);
|
||||
});
|
||||
}
|
||||
@ -284,7 +299,7 @@ void dynamic_per_token_scaled_fp8_quant(
|
||||
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
|
||||
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
|
||||
out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
||||
hidden_size);
|
||||
|
||||
Reference in New Issue
Block a user