From c5353b841217f75872370fa9a54bc1f4ad1e9a40 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Sat, 24 Jan 2026 11:38:42 +0800
Subject: [PATCH] change fp4 init range

---
 problems/nvidia/nvfp4_group_gemm/reference.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/problems/nvidia/nvfp4_group_gemm/reference.py b/problems/nvidia/nvfp4_group_gemm/reference.py
index f12f504..7ce6bc0 100644
--- a/problems/nvidia/nvfp4_group_gemm/reference.py
+++ b/problems/nvidia/nvfp4_group_gemm/reference.py
@@ -123,6 +123,17 @@ def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor):
     return reordered_f8_tensor
 
 
+def _create_fp4_tensors(l, mn, k):
+    # generate a uint8 tensor, then reinterpret it as the float4_e2m1fn_x2 data type
+    # generate all 256 bit patterns (randint's upper bound is exclusive)
+    ref_i8 = torch.randint(256, size=(l, mn, k // 2), dtype=torch.uint8, device="cuda")
+
+    # for each nibble, keep only the sign bit and the 2 LSBs
+    # the possible values are [-1.5, -1, -0.5, 0, +0.5, +1, +1.5]
+    ref_i8 = ref_i8 & 0b1011_1011
+    return ref_i8.permute(1, 2, 0).view(torch.float4_e2m1fn_x2)
+
+
 def generate_input(
     m: tuple,
     n: tuple,
@@ -165,14 +176,8 @@ def generate_input(
         mi = m[group_idx]
         ni = n[group_idx]
         ki = k[group_idx]
-        a_ref = torch.randint(
-            -1, 2, (l, mi, ki // 2), dtype=torch.int8, device="cuda"
-        ).permute(1, 2, 0)
-        b_ref = torch.randint(
-            -1, 2, (l, ni, ki // 2), dtype=torch.int8, device="cuda"
-        ).permute(1, 2, 0)
-        a_ref = a_ref.view(torch.float4_e2m1fn_x2)
-        b_ref = b_ref.view(torch.float4_e2m1fn_x2)
+        a_ref = _create_fp4_tensors(l, mi, ki)
+        b_ref = _create_fp4_tensors(l, ni, ki)
         c_ref = torch.randn((l, mi, ni), dtype=torch.float16, device="cuda").permute(
             1, 2, 0