21 changes: 13 additions & 8 deletions problems/nvidia/nvfp4_group_gemm/reference.py
@@ -123,6 +123,17 @@ def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor):
return reordered_f8_tensor


def _create_fp4_tensors(l, mn, k):
    # generate a uint8 tensor, then reinterpret it as the float4_e2m1fn_x2 data type
    # generate all 256 byte patterns (each byte packs two FP4 nibbles)
    ref_i8 = torch.randint(256, size=(l, mn, k // 2), dtype=torch.uint8, device="cuda")

    # for each nibble, keep only the sign bit and the 2 LSBs
    # the possible values are [-1.5, -1, -0.5, 0, +0.5, +1, +1.5]
    ref_i8 = ref_i8 & 0b1011_1011
    return ref_i8.permute(1, 2, 0).view(torch.float4_e2m1fn_x2)


def generate_input(
    m: tuple,
    n: tuple,
@@ -165,14 +176,8 @@ def generate_input(
        mi = m[group_idx]
        ni = n[group_idx]
        ki = k[group_idx]
        a_ref = torch.randint(
            -1, 2, (l, mi, ki // 2), dtype=torch.int8, device="cuda"
        ).permute(1, 2, 0)
        b_ref = torch.randint(
            -1, 2, (l, ni, ki // 2), dtype=torch.int8, device="cuda"
        ).permute(1, 2, 0)
        a_ref = a_ref.view(torch.float4_e2m1fn_x2)
        b_ref = b_ref.view(torch.float4_e2m1fn_x2)
        a_ref = _create_fp4_tensors(l, mi, ki)
        b_ref = _create_fp4_tensors(l, ni, ki)

        c_ref = torch.randn((l, mi, ni), dtype=torch.float16, device="cuda").permute(
            1, 2, 0
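As a side note on the nibble mask used in `_create_fp4_tensors`: the standalone sketch below (not part of reference.py; the `decode_e2m1` helper exists purely for illustration) decodes the E2M1 nibble encoding by hand and confirms that masking each nibble with 0b1011 leaves exactly the value set stated in the comment, [-1.5, -1, -0.5, 0, +0.5, +1, +1.5].

```python
# Standalone sanity check of the 0b1011 nibble mask (illustration only, not part
# of reference.py). E2M1 nibble layout: 1 sign bit, 2 exponent bits (bias 1),
# 1 mantissa bit.

def decode_e2m1(nibble: int) -> float:
    sign = -1.0 if nibble & 0b1000 else 1.0
    exp = (nibble >> 1) & 0b11
    man = nibble & 0b1
    if exp == 0:
        # subnormal: +/- mantissa * 0.5
        return sign * man * 0.5
    # normal: +/- (1 + mantissa/2) * 2**(exp - 1)
    return sign * (1.0 + man * 0.5) * 2.0 ** (exp - 1)

# Unmasked, the 16 nibble patterns cover the full E2M1 range up to +/-6.
all_values = sorted({decode_e2m1(n) for n in range(16)})
# Masking with 0b1011 clears the exponent MSB, so only exponents 0 and 1 survive.
masked_values = sorted({decode_e2m1(n & 0b1011) for n in range(16)})

print(all_values)     # [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
print(masked_values)  # [-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5]
```

Restricting the reference inputs to this small, exactly representable value set keeps the FP4 group GEMM reference numerically well behaved while still exercising every surviving bit pattern.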