From dcea6815eaa20604f56ecf8102974db009d933dd Mon Sep 17 00:00:00 2001
From: zhushuang <974198603@qq.com>
Date: Thu, 22 Jan 2026 19:06:56 +0800
Subject: [PATCH] issue/949 - feat: add silu_and_mul for Moore GPU with
 passing tests

---
 include/infinicore/ops.hpp | 1 +
 include/infinicore/ops/silu_and_mul.hpp | 14 ++
 include/infiniop.h | 1 +
 include/infiniop/ops/silu_and_mul.h | 29 +++
 python/infinicore/nn/functional/__init__.py | 2 +
 .../infinicore/nn/functional/silu_and_mul.py | 17 ++
 .../ops/silu_and_mul/silu_and_mul.cc | 35 ++++
 .../ops/silu_and_mul/silu_and_mul_infiniop.cc | 50 +++++
 src/infinicore/pybind11/ops.hpp | 2 +
 src/infinicore/pybind11/ops/silu_and_mul.hpp | 29 +++
 src/infiniop/ops/silu_and_mul/info.h | 54 ++++++
 .../silu_and_mul/moore/silu_and_mul_moore.h | 8 +
 .../silu_and_mul/moore/silu_and_mul_moore.mu | 123 ++++++++++++
 src/infiniop/ops/silu_and_mul/operator.cc | 79 ++++++++
 src/infiniop/ops/silu_and_mul/silu_and_mul.h | 46 +++++
 test/infinicore/ops/silu_and_mul.py | 126 +++++++++++++
 test/infiniop/libinfiniop/op_register.py | 32 ++++
 test/infiniop/silu_and_mul.py | 176 ++++++++++++++++++
 18 files changed, 824 insertions(+)
 create mode 100644 include/infinicore/ops/silu_and_mul.hpp
 create mode 100644 include/infiniop/ops/silu_and_mul.h
 create mode 100644 python/infinicore/nn/functional/silu_and_mul.py
 create mode 100644 src/infinicore/ops/silu_and_mul/silu_and_mul.cc
 create mode 100644 src/infinicore/ops/silu_and_mul/silu_and_mul_infiniop.cc
 create mode 100644 src/infinicore/pybind11/ops/silu_and_mul.hpp
 create mode 100644 src/infiniop/ops/silu_and_mul/info.h
 create mode 100644 src/infiniop/ops/silu_and_mul/moore/silu_and_mul_moore.h
 create mode 100644 src/infiniop/ops/silu_and_mul/moore/silu_and_mul_moore.mu
 create mode 100644 src/infiniop/ops/silu_and_mul/operator.cc
 create mode 100644 src/infiniop/ops/silu_and_mul/silu_and_mul.h
 create mode 100644 test/infinicore/ops/silu_and_mul.py
 create mode 100644 test/infiniop/silu_and_mul.py

diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp
index a7249ec9d..b12560c8f 100644
--- a/include/infinicore/ops.hpp
+++ b/include/infinicore/ops.hpp
@@ -14,4 +14,5 @@
 #include "ops/rms_norm.hpp"
 #include "ops/rope.hpp"
 #include "ops/silu.hpp"
+#include "ops/silu_and_mul.hpp"
 #include "ops/swiglu.hpp"
diff --git a/include/infinicore/ops/silu_and_mul.hpp b/include/infinicore/ops/silu_and_mul.hpp
new file mode 100644
index 000000000..c8be68a99
--- /dev/null
+++ b/include/infinicore/ops/silu_and_mul.hpp
@@ -0,0 +1,14 @@
+#pragma once
+
+#include "../device.hpp"
+#include "../graph/graph.hpp"
+#include "common/op.hpp"
+
+namespace infinicore::op {
+
+INFINICORE_GRAPH_OP_CLASS(SiluAndMul, Tensor, Tensor);
+
+Tensor silu_and_mul(Tensor x);
+void silu_and_mul_(Tensor out, Tensor x);
+
+} // namespace infinicore::op
diff --git a/include/infiniop.h b/include/infiniop.h
index c0a09fcb4..edc4753c6 100644
--- a/include/infiniop.h
+++ b/include/infiniop.h
@@ -26,6 +26,7 @@
 #include "infiniop/ops/rope.h"
 #include "infiniop/ops/sigmoid.h"
 #include "infiniop/ops/silu.h"
+#include "infiniop/ops/silu_and_mul.h"
 #include "infiniop/ops/softmax.h"
 #include "infiniop/ops/softplus.h"
 #include "infiniop/ops/sub.h"
diff --git a/include/infiniop/ops/silu_and_mul.h b/include/infiniop/ops/silu_and_mul.h
new file mode 100644
index 000000000..a07be6d54
--- /dev/null
+++ b/include/infiniop/ops/silu_and_mul.h
@@ -0,0 +1,29 @@
+#ifndef __INFINIOP_SILU_AND_MUL_API_H__
+#define __INFINIOP_SILU_AND_MUL_API_H__
+
+#include "../operator_descriptor.h"
+
+typedef struct InfiniopDescriptor *infiniopSiluAndMulDescriptor_t;
+
+__C __export infiniStatus_t infiniopCreateSiluAndMulDescriptor(
+    infiniopHandle_t handle,
+    infiniopSiluAndMulDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t output,
+    infiniopTensorDescriptor_t input);
+
+__C __export infiniStatus_t infiniopGetSiluAndMulWorkspaceSize(
+    infiniopSiluAndMulDescriptor_t desc,
+    size_t *size);
+
+__C __export infiniStatus_t infiniopSiluAndMul(
+    infiniopSiluAndMulDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    void *output,
+    const void *input,
+    void *stream);
+
+__C __export infiniStatus_t infiniopDestroySiluAndMulDescriptor(
+    infiniopSiluAndMulDescriptor_t desc);
+
+#endif // __INFINIOP_SILU_AND_MUL_API_H__
diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py
index 255079790..c5005617d 100644
--- a/python/infinicore/nn/functional/__init__.py
+++ b/python/infinicore/nn/functional/__init__.py
@@ -5,6 +5,7 @@
 from .rms_norm import rms_norm
 from .rope import RopeAlgo, rope
 from .silu import silu
+from .silu_and_mul import silu_and_mul
 from .swiglu import swiglu
 
 __all__ = [
@@ -17,4 +18,5 @@
     "embedding",
     "rope",
     "RopeAlgo",
+    "silu_and_mul",
 ]
diff --git a/python/infinicore/nn/functional/silu_and_mul.py b/python/infinicore/nn/functional/silu_and_mul.py
new file mode 100644
index 000000000..395754362
--- /dev/null
+++ b/python/infinicore/nn/functional/silu_and_mul.py
@@ -0,0 +1,17 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+
+def silu_and_mul(input: Tensor, out=None) -> Tensor:
+    r"""Apply the SiLU and Mul (SwiGLU) function.
+
+    Formula: output = SiLU(input_gate) * input_up
+    Input shape: [..., 2*d], Output shape: [..., d]
+    """
+
+    if out is None:
+        return Tensor(_infinicore.silu_and_mul(input._underlying))
+
+    _infinicore.silu_and_mul_(out._underlying, input._underlying)
+
+    return out
diff --git a/src/infinicore/ops/silu_and_mul/silu_and_mul.cc b/src/infinicore/ops/silu_and_mul/silu_and_mul.cc
new file mode 100644
index 000000000..84735df40
--- /dev/null
+++ b/src/infinicore/ops/silu_and_mul/silu_and_mul.cc
@@ -0,0 +1,35 @@
+#include "infinicore/ops/silu_and_mul.hpp"
+#include "../../utils.hpp"
+
+namespace infinicore::op {
+
+INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(SiluAndMul);
+
+SiluAndMul::SiluAndMul(Tensor out, Tensor x) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, x);
+    INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, x);
+}
+
+void SiluAndMul::execute(Tensor out, Tensor x) {
+    INFINICORE_GRAPH_OP_RECORD_OR_RUN(SiluAndMul, out, x);
+}
+
+Tensor silu_and_mul(Tensor x) {
+    Shape shape = x->shape();
+    size_t ndim = x->ndim();
+
+    if (shape[ndim - 1] % 2 != 0) {
+        throw std::runtime_error("SiluAndMul input last dim must be even.");
+    }
+    shape[ndim - 1] /= 2;
+
+    auto out = Tensor::empty(shape, x->dtype(), x->device());
+    silu_and_mul_(out, x);
+    return out;
+}
+
+void silu_and_mul_(Tensor out, Tensor x) {
+    SiluAndMul::execute(out, x);
+}
+
+} // namespace infinicore::op
diff --git a/src/infinicore/ops/silu_and_mul/silu_and_mul_infiniop.cc b/src/infinicore/ops/silu_and_mul/silu_and_mul_infiniop.cc
new file mode 100644
index 000000000..e03a01c84
--- /dev/null
+++ b/src/infinicore/ops/silu_and_mul/silu_and_mul_infiniop.cc
@@ -0,0 +1,50 @@
+#include "../infiniop_impl.hpp"
+#include "infinicore/ops/silu_and_mul.hpp"
+
+namespace infinicore::op::silu_and_mul_impl::infiniop {
+
+INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, SiluAndMul, 100);
+
+struct PlannedMeta {
+    std::shared_ptr<Descriptor> descriptor;
+    graph::GraphTensor workspace, output, input;
+};
+
+void *plan(Tensor output, Tensor input) {
+    size_t seed = hash_combine(output, input);
+
+    INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
+        Descriptor, descriptor, SiluAndMul,
+        seed, output->desc(), input->desc());
+
+    INFINIOP_WORKSPACE_TENSOR(workspace, SiluAndMul, descriptor);
+
+    auto planned = new PlannedMeta{
+        descriptor,
+        graph::GraphTensor(workspace),
+        graph::GraphTensor(output),
+        graph::GraphTensor(input)};
+
+    return planned;
+}
+
+void run(void *planned_meta) {
+    auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
+
+    INFINICORE_CHECK_ERROR(infiniopSiluAndMul(
+        planned->descriptor->desc,
+        planned->workspace->data(),
+        planned->workspace->numel(),
+        planned->output->data(),
+        planned->input->data(),
+        context::getStream()));
+}
+
+void cleanup(void **planned_meta_ptr) {
+    delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
+    *planned_meta_ptr = nullptr;
+}
+
+INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(SiluAndMul, &plan, &run, &cleanup);
+
+} // namespace infinicore::op::silu_and_mul_impl::infiniop
diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp
index 3d6ebe79a..71d7db8ae 100644
--- a/src/infinicore/pybind11/ops.hpp
+++ b/src/infinicore/pybind11/ops.hpp
@@ -18,6 +18,7 @@
 #include "ops/rms_norm.hpp"
 #include "ops/rope.hpp"
 #include "ops/silu.hpp"
+#include "ops/silu_and_mul.hpp"
 #include "ops/swiglu.hpp"
 
 namespace py = pybind11;
@@ -42,6 +43,7 @@ inline void bind(py::module &m) {
     bind_swiglu(m);
     bind_rope(m);
     bind_embedding(m);
+    bind_silu_and_mul(m);
 }
 
 } // namespace infinicore::ops
diff --git a/src/infinicore/pybind11/ops/silu_and_mul.hpp b/src/infinicore/pybind11/ops/silu_and_mul.hpp
new file mode 100644
index 000000000..009c3b533
--- /dev/null
+++ b/src/infinicore/pybind11/ops/silu_and_mul.hpp
@@ -0,0 +1,29 @@
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+#include "infinicore/ops/silu_and_mul.hpp"
+
+namespace py = pybind11;
+
+namespace infinicore::ops {
+
+inline void bind_silu_and_mul(py::module &m) {
+    m.def("silu_and_mul",
+          &op::silu_and_mul,
+          py::arg("input"),
+          R"doc(
+          SiLU and Mul (SwiGLU) activation function.
+          Input should be [..., 2*d], output will be [..., d].
+          )doc");
+
+    m.def("silu_and_mul_",
+          &op::silu_and_mul_,
+          py::arg("output"),
+          py::arg("input"),
+          R"doc(
+          Destination-passing SiLU and Mul (SwiGLU): writes SiLU(gate) * up of `input`
+          into the pre-allocated `output` tensor. Note that a true in-place update is
+          not possible, since the output shape [..., d] differs from the input shape [..., 2*d].
+          )doc");
+}
+
+} // namespace infinicore::ops
diff --git a/src/infiniop/ops/silu_and_mul/info.h b/src/infiniop/ops/silu_and_mul/info.h
new file mode 100644
index 000000000..91ffeb1bf
--- /dev/null
+++ b/src/infiniop/ops/silu_and_mul/info.h
@@ -0,0 +1,54 @@
+#ifndef __SILU_AND_MUL_INFO_H__
+#define __SILU_AND_MUL_INFO_H__
+
+#include "../../../utils.h"
+#include "../../tensor.h"
+#include
+
+namespace op::silu_and_mul {
+
+class SiluAndMulInfo {
+    SiluAndMulInfo() = default;
+
+public:
+    infiniDtype_t dtype;
+    size_t batch_size;
+    size_t out_hidden_dim;
+
+    static utils::Result<SiluAndMulInfo> create(infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t x_desc) {
+        auto dtype = y_desc->dtype();
+
+        auto x_shape = x_desc->shape();
+        auto y_shape = y_desc->shape();
+        auto ndim = x_desc->ndim();
+
+        if (ndim != y_desc->ndim()) {
+            return INFINI_STATUS_BAD_PARAM;
+        }
+
+        if (x_shape[ndim - 1] != 2 * y_shape[ndim - 1]) {
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
+
+        size_t batch = 1;
+        for (int i = 0; i < (int)ndim - 1; ++i) {
+            if (x_shape[i] != y_shape[i]) {
+                return INFINI_STATUS_BAD_TENSOR_SHAPE;
+            }
+            batch *= y_shape[i];
+        }
+
+        return utils::Result<SiluAndMulInfo>(SiluAndMulInfo{
+            dtype,
+            batch,
+            y_shape[ndim - 1]});
+    }
+
+private:
+    SiluAndMulInfo(infiniDtype_t dtype, size_t batch, size_t hidden)
+        : dtype(dtype), batch_size(batch), out_hidden_dim(hidden) {}
+};
+
+} // namespace op::silu_and_mul
+
+#endif // __SILU_AND_MUL_INFO_H__
diff --git a/src/infiniop/ops/silu_and_mul/moore/silu_and_mul_moore.h b/src/infiniop/ops/silu_and_mul/moore/silu_and_mul_moore.h
new file mode 100644
index 000000000..2789e2de6
--- /dev/null
+++ b/src/infiniop/ops/silu_and_mul/moore/silu_and_mul_moore.h
@@ -0,0 +1,8 @@
+#ifndef __SILU_AND_MUL_MOORE_API_H__
+#define __SILU_AND_MUL_MOORE_API_H__
+
+#include "../silu_and_mul.h"
+
+DESCRIPTOR(moore)
+
+#endif // __SILU_AND_MUL_MOORE_API_H__
diff --git a/src/infiniop/ops/silu_and_mul/moore/silu_and_mul_moore.mu b/src/infiniop/ops/silu_and_mul/moore/silu_and_mul_moore.mu
new file mode 100644
index 000000000..48fcb9609
--- /dev/null
+++ b/src/infiniop/ops/silu_and_mul/moore/silu_and_mul_moore.mu
@@ -0,0 +1,123 @@
+#include "../../../devices/moore/moore_common.h"
+#include "../../../devices/moore/moore_handle.h"
+#include "silu_and_mul_moore.h"
+
+#include <array>
+#include <type_traits>
+
+namespace op::silu_and_mul::moore {
+
+struct Descriptor::Opaque {
+    std::shared_ptr<device::moore::Handle::Internal> internal;
+};
+
+Descriptor::~Descriptor() {
+    delete _opaque;
+}
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t y_desc,
+    infiniopTensorDescriptor_t x_desc) {
+
+    if (!desc_ptr) {
+        return INFINI_STATUS_BAD_PARAM;
+    }
+
+    auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
+    auto dtype = y_desc->dtype();
+
+    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
+    if (x_desc->dtype() != dtype) {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    auto result = SiluAndMulInfo::create(y_desc, x_desc);
+    CHECK_RESULT(result);
+    auto info = result.take();
+
+    *desc_ptr = new Descriptor(
+        new Opaque{handle->internal()},
+        std::move(info),
+        0,
+        handle->device, handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+template <typename T>
+infiniStatus_t calculate_impl(
+    const SiluAndMulInfo &info,
+    std::shared_ptr<device::moore::Handle::Internal> &internal,
+    void *y,
+    const void *x,
+    void *stream) {
+
+    return internal->useMudnn(
+        (musaStream_t)stream,
+        [&](::musa::dnn::Handle &mudnn_handle) -> infiniStatus_t {
+
+            ::musa::dnn::Tensor x_t, y_t;
+
+            if constexpr (std::is_same_v<T, half>) {
+                x_t.SetType(::musa::dnn::Tensor::Type::HALF);
+                y_t.SetType(::musa::dnn::Tensor::Type::HALF);
+            } else if constexpr (std::is_same_v<T, __mt_bfloat16>) {
+                x_t.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
+                y_t.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
+            } else {
+                x_t.SetType(::musa::dnn::Tensor::Type::FLOAT);
+                y_t.SetType(::musa::dnn::Tensor::Type::FLOAT);
+            }
+
+            x_t.SetAddr(const_cast<void *>(x));
+            y_t.SetAddr(y);
+
+            // --- Construct 2D dimension information ---
+            // Explicitly distinguish between batch and hidden dimensions.
+            int64_t b = static_cast<int64_t>(info.batch_size);
+            int64_t h = static_cast<int64_t>(info.out_hidden_dim);
+
+            // Input x logical shape is [batch, 2 * hidden]
+            std::array<int64_t, 2> x_dims = {b, h * 2};
+            std::array<int64_t, 2> x_strides = {h * 2, 1};
+
+            // Output y logical shape is [batch, hidden]
+            std::array<int64_t, 2> y_dims = {b, h};
+            std::array<int64_t, 2> y_strides = {h, 1};
+
+            x_t.SetNdInfo(2, x_dims.data(), x_strides.data());
+            y_t.SetNdInfo(2, y_dims.data(), y_strides.data());
+
+            // Invoke muDNN SwiGLU.
+            // muDNN splits each input row (length 2*h) internally: the first h
+            // elements of x are treated as the 'gate' and the following h
+            // elements as the 'up' projection.
+            ::musa::dnn::SwiGlu swiglu;
+            swiglu.Run(mudnn_handle, y_t, x_t);
+
+            return INFINI_STATUS_SUCCESS;
+        });
+}
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace, size_t workspace_size,
+    void *y, const void *x,
+    void *stream) const {

+    infiniDtype_t dtype = _info.dtype;
+
+    switch (dtype) {
+    case INFINI_DTYPE_F16:
+        return calculate_impl<half>(_info, _opaque->internal, y, x, stream);
+    case INFINI_DTYPE_F32:
+        return calculate_impl<float>(_info, _opaque->internal, y, x, stream);
+    case INFINI_DTYPE_BF16:
+        return calculate_impl<__mt_bfloat16>(_info, _opaque->internal, y, x, stream);
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+}
+
+} // namespace op::silu_and_mul::moore
diff --git a/src/infiniop/ops/silu_and_mul/operator.cc b/src/infiniop/ops/silu_and_mul/operator.cc
new file mode 100644
index 000000000..5cf73c93d
--- /dev/null
+++ b/src/infiniop/ops/silu_and_mul/operator.cc
@@ -0,0 +1,79 @@
+#include "../../operator.h"
+#include "../../handle.h"
+#include "infiniop/ops/silu_and_mul.h"
+
+#ifdef ENABLE_MOORE_API
+#include "moore/silu_and_mul_moore.h"
+#endif
+
+__C infiniStatus_t infiniopCreateSiluAndMulDescriptor(
+    infiniopHandle_t handle,
+    infiniopSiluAndMulDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t y_desc,
+    infiniopTensorDescriptor_t x_desc) {
+
+#define CREATE(CASE, NAMESPACE)                                                         \
+    case CASE:                                                                          \
+        return op::silu_and_mul::NAMESPACE::Descriptor::create(                         \
+            handle,                                                                     \
+            reinterpret_cast<op::silu_and_mul::NAMESPACE::Descriptor **>(desc_ptr),     \
+            y_desc,                                                                     \
+            x_desc);
+
+    switch (handle->device) {
+#ifdef ENABLE_MOORE_API
+        CREATE(INFINI_DEVICE_MOORE, moore);
+#endif
+    }
+    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+}
+
+__C infiniStatus_t infiniopGetSiluAndMulWorkspaceSize(infiniopSiluAndMulDescriptor_t desc, size_t *size) {
+
+#define GET(CASE, NAMESPACE)                                                                        \
+    case CASE:                                                                                      \
+        *size = reinterpret_cast<op::silu_and_mul::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_MOORE_API
+        GET(INFINI_DEVICE_MOORE, moore);
+#endif
+    }
+    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+}
+
+__C infiniStatus_t infiniopSiluAndMul(
+    infiniopSiluAndMulDescriptor_t desc,
+    void *workspace, size_t workspace_size,
+    void *y,
+    const void *x,
+    void *stream) {
+
+#define CALCULATE(CASE, NAMESPACE)                                                           \
+    case CASE:                                                                               \
+        return reinterpret_cast<op::silu_and_mul::NAMESPACE::Descriptor *>(desc)->calculate( \
+            workspace, workspace_size, y, x, stream);
+
+    switch (desc->device_type) {
+#ifdef ENABLE_MOORE_API
+        CALCULATE(INFINI_DEVICE_MOORE, moore);
+#endif
+    }
+    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+}
+
+__C infiniStatus_t infiniopDestroySiluAndMulDescriptor(infiniopSiluAndMulDescriptor_t desc) {
+
+#define DESTROY(CASE, NAMESPACE)                                                   \
+    case CASE:                                                                     \
+        delete reinterpret_cast<op::silu_and_mul::NAMESPACE::Descriptor *>(desc);  \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_MOORE_API
+        DESTROY(INFINI_DEVICE_MOORE, moore);
+#endif
+    }
+    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+}
diff --git a/src/infiniop/ops/silu_and_mul/silu_and_mul.h b/src/infiniop/ops/silu_and_mul/silu_and_mul.h
new file mode 100644
index 000000000..ced75e68d
--- /dev/null
+++ b/src/infiniop/ops/silu_and_mul/silu_and_mul.h
@@ -0,0 +1,46 @@
+#ifndef SILU_AND_MUL_H
+#define SILU_AND_MUL_H
+
+#include "../../operator.h"
+#include "info.h"
+
+#define DESCRIPTOR(NAMESPACE)                                     \
+                                                                  \
+    namespace op::silu_and_mul::NAMESPACE {                       \
+    class Descriptor final : public InfiniopDescriptor {          \
+        struct Opaque;                                            \
+        Opaque *_opaque;                                          \
+        SiluAndMulInfo _info;                                     \
+        size_t _workspace_size;                                   \
+                                                                  \
+        Descriptor(                                               \
+            Opaque *opaque,                                       \
+            SiluAndMulInfo info,                                  \
+            size_t workspace_size,                                \
+            infiniDevice_t device_type,                           \
+            int device_id)                                        \
+            : InfiniopDescriptor{device_type, device_id},         \
+              _opaque(opaque),                                    \
+              _info(info),                                        \
+              _workspace_size(workspace_size) {}                  \
+                                                                  \
+    public:                                                       \
+        ~Descriptor();                                            \
+                                                                  \
+        size_t workspaceSize() const { return _workspace_size; }  \
+                                                                  \
+        static infiniStatus_t create(                             \
+            infiniopHandle_t handle,                              \
+            Descriptor **desc_ptr,                                \
+            infiniopTensorDescriptor_t y_desc,                    \
+            infiniopTensorDescriptor_t x_desc);                   \
+                                                                  \
+        infiniStatus_t calculate(                                 \
+            void *workspace, size_t workspace_size,               \
+            void *y,                                              \
+            const void *x,                                        \
+            void *stream) const;                                  \
+    };                                                            \
+    }
+
+#endif // SILU_AND_MUL_H
diff --git a/test/infinicore/ops/silu_and_mul.py b/test/infinicore/ops/silu_and_mul.py
new file mode 100644
index 000000000..5e3feb18e
--- /dev/null
+++ b/test/infinicore/ops/silu_and_mul.py
@@ -0,0 +1,126 @@
+import sys
+import os
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+import torch
+import infinicore
+from framework import (
+    BaseOperatorTest,
+    TensorSpec,
+    TestCase,
+    GenericTestRunner,
+    is_broadcast,
+)
+
+# ==============================================================================
+# Operator-specific configuration
+# ==============================================================================
+
+# Test cases format: (input_shape)
+# The operator splits the last dimension: input (..., 2*d) -> output (..., d)
+_TEST_CASES_DATA = [
+    (2, 4),
+    (1024, 1024),
+    (2, 4, 8),
+    (1, 22016),
+    (2, 4, 256),
+    (2, 4, 16, 256),
+]
+
+# Tolerance configuration for different precisions
+_TOLERANCE_MAP = {
+    infinicore.float16: {"atol": 1e-3, "rtol": 1e-3},
+    infinicore.float32: {"atol": 1e-5, "rtol": 1e-5},
+    infinicore.bfloat16: {"atol": 5e-3, "rtol": 1e-2},
+}
+
+_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
+
+
+def parse_test_cases():
+    """
+    Parse SiLUAndMul test case data.
+    Input shape: [..., 2*d], Output shape: [..., d]
+    Note: In-place is not supported due to the shape mismatch between input and output.
+    """
+    test_cases = []
+
+    for input_shape in _TEST_CASES_DATA:
+        # Calculate output shape based on SwiGLU logic
+        output_shape = list(input_shape)
+        output_shape[-1] //= 2
+        output_shape = tuple(output_shape)
+
+        for dtype in _TENSOR_DTYPES:
+            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4})
+
+            input_spec = TensorSpec.from_tensor(input_shape, None, dtype)
+            output_spec = TensorSpec.from_tensor(output_shape, None, dtype)
+
+            # Case 1: Functional style (allocates new memory for output)
+            test_cases.append(
+                TestCase(
+                    inputs=[input_spec],
+                    kwargs={},
+                    output_spec=None,
+                    comparison_target=None,
+                    tolerance=tolerance,
+                    description=f"SiLUAndMul_Functional_{dtype}",
+                )
+            )
+
+            # Case 2: Explicit output tensor style (uses a pre-allocated buffer)
+            test_cases.append(
+                TestCase(
+                    inputs=[input_spec],
+                    kwargs=None,
+                    output_spec=output_spec,
+                    comparison_target="out",
+                    tolerance=tolerance,
+                    description=f"SiLUAndMul_OutParam_{dtype}",
+                )
+            )
+
+    return test_cases
+
+
+class OpTest(BaseOperatorTest):
+    """SiLUAndMul operator test (SwiGLU activation)"""
+
+    def __init__(self):
+        super().__init__("SiLUAndMul")
+
+    def get_test_cases(self):
+        return parse_test_cases()
+
+    def torch_operator(self, input, out=None, **kwargs):
+        """
+        PyTorch SwiGLU reference implementation:
+        Formula: SiLU(gate) * up, where [gate, up] = split(input)
+        """
+        d = input.shape[-1] // 2
+        # Split the last dimension into two equal parts
+        gate, up = torch.split(input, [d, d], dim=-1)
+        result = torch.nn.functional.silu(gate) * up
+
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+
+    def infinicore_operator(self, input, out=None, **kwargs):
+        """InfiniCore SiLUAndMul implementation wrapper"""
+        import infinicore.nn.functional as F
+
+        return F.silu_and_mul(input, out=out)
+
+
+def main():
+    """Main entry point for the test runner"""
+    runner = GenericTestRunner(OpTest)
+    runner.run_and_exit()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py
index 618be2b05..8ebbaed54 100644
--- a/test/infiniop/libinfiniop/op_register.py
+++ b/test/infiniop/libinfiniop/op_register.py
@@ -1144,3 +1144,35 @@ def paged_attention_prefill_(lib):
     lib.infiniopDestroyPagedAttentionPrefillDescriptor.argtypes = [
         infiniopOperatorDescriptor_t,
     ]
+
+
+@OpRegister.operator
+def silu_and_mul(lib):
+    lib.infiniopCreateSiluAndMulDescriptor.restype = c_int32
+    lib.infiniopCreateSiluAndMulDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopOperatorDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+    ]
+
+    lib.infiniopGetSiluAndMulWorkspaceSize.restype = c_int32
+    lib.infiniopGetSiluAndMulWorkspaceSize.argtypes = [
+        infiniopOperatorDescriptor_t,
+        POINTER(c_size_t),
+    ]
+
+    lib.infiniopSiluAndMul.restype = c_int32
+    lib.infiniopSiluAndMul.argtypes = [
+        infiniopOperatorDescriptor_t,
+        c_void_p,
+        c_size_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+
+    lib.infiniopDestroySiluAndMulDescriptor.restype = c_int32
+    lib.infiniopDestroySiluAndMulDescriptor.argtypes = [
+        infiniopOperatorDescriptor_t,
+    ]
diff --git a/test/infiniop/silu_and_mul.py b/test/infiniop/silu_and_mul.py
new file mode 100644
index 000000000..b9117f55f
--- /dev/null
+++ b/test/infiniop/silu_and_mul.py
@@ -0,0 +1,176 @@
+import torch
+import ctypes
+from ctypes import c_uint64
+from libinfiniop import (
+    LIBINFINIOP,
+    TestTensor,
+    get_test_devices,
+    check_error,
+    test_operator,
+    get_args,
+    debug,
+    get_tolerance,
+    profile_operation,
+    TestWorkspace,
+    InfiniDtype,
+    InfiniDtypeNames,
+    InfiniDeviceNames,
+    infiniopOperatorDescriptor_t,
+)
+
+# ==============================================================================
+# Configuration (Internal Use Only)
+# ==============================================================================
+# Format: (input_shape, output_shape)
+# Following the vLLM silu_and_mul kernel interface:
+# input_shape is [..., 2*d], output_shape is [..., d]
+_TEST_CASES = [
+    # input_shape, output_shape
+    ((2, 8), (2, 4)),
+    ((1024, 1024), (1024, 512)),
+    ((16, 8192), (16, 4096)),
+    ((2, 128, 2048), (2, 128, 1024)),
+    ((8, 1, 4096), (8, 1, 2048)),
+    ((2, 4, 16, 256), (2, 4, 16, 128)),
+]
+
+_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]
+
+_TOLERANCE_MAP = {
+    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3},
+    InfiniDtype.F32: {"atol": 1e-6, "rtol": 1e-6},
+    InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
+}
+
+DEBUG = False
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 100
+
+
+# PyTorch reference: silu(gate) * up where [gate, up] = split(input)
+def silu_and_mul_torch(out, input_tensor):
+    """
+    Computes the SwiGLU activation function: SiLU(gate) * up.
+    """
+    # Split the last dimension into two halves:
+    # the first half is 'gate', the second is 'up'
+    d = input_tensor.shape[-1] // 2
+    gate = input_tensor[..., :d]
+    up = input_tensor[..., d:]
+
+    # Apply SiLU to the gate and multiply by the up projection
+    torch.mul(torch.nn.functional.silu(gate), up, out=out)
+
+
+# ==============================================================================
+# Test Logic
+# ==============================================================================
+def test(
+    handle,
+    device,
+    input_shape,
+    output_shape,
+    dtype=InfiniDtype.F16,
+    sync=None,
+):
+    print(
+        f"Testing SiluAndMul on {InfiniDeviceNames[device]} with "
+        f"input_shape:{input_shape} output_shape:{output_shape} dtype:{InfiniDtypeNames[dtype]}"
+    )
+
+    a = TestTensor(input_shape, None, dtype, device)
+    c = TestTensor(output_shape, None, dtype, device, mode="zeros")
+    ans = TestTensor(output_shape, None, dtype, device, mode="zeros")
+
+    # Only contiguous tensors are supported
+    if not (
+        a.torch_tensor().is_contiguous()
+        and c.torch_tensor().is_contiguous()
+        and ans.torch_tensor().is_contiguous()
+    ):
+        raise ValueError("This operator only supports contiguous memory layouts.")
+
+    # PyTorch reference answer
+    def torch_silu_and_mul_reference():
+        silu_and_mul_torch(ans.torch_tensor(), a.torch_tensor())
+
+    torch_silu_and_mul_reference()
+
+    if sync is not None:
+        sync()
+
+    descriptor = infiniopOperatorDescriptor_t()
+    check_error(
+        LIBINFINIOP.infiniopCreateSiluAndMulDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            c.descriptor,
+            a.descriptor,
+        )
+    )
+
+    for tensor in [a, c]:
+        tensor.destroy_desc()
+
+    # Workspace
+    workspace_size = c_uint64(0)
+    check_error(
+        LIBINFINIOP.infiniopGetSiluAndMulWorkspaceSize(
+            descriptor, ctypes.byref(workspace_size)
+        )
+    )
+    workspace = TestWorkspace(workspace_size.value, device)
+
+    def lib_op():
+        check_error(
+            LIBINFINIOP.infiniopSiluAndMul(
+                descriptor,
+                workspace.data(),
+                workspace_size.value,
+                c.data(),
+                a.data(),
+                None,
+            )
+        )
+
+    lib_op()
+
+    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
+
+    if DEBUG:
+        debug(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol)
+
+    assert torch.allclose(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol)
+
+    # Profiling workflow
+    if PROFILE:
+        profile_operation(
+            "PyTorch",
+            lambda: torch_silu_and_mul_reference(),
+            device,
+            NUM_PRERUN,
+            NUM_ITERATIONS,
+        )
+        profile_operation(
+            "    lib", lambda: lib_op(), device, NUM_PRERUN, NUM_ITERATIONS
+        )
+
+    check_error(LIBINFINIOP.infiniopDestroySiluAndMulDescriptor(descriptor))
+
+
+# ==============================================================================
+# Main Execution
+# ==============================================================================
+if __name__ == "__main__":
+    args = get_args()
+
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+
+    for device in get_test_devices(args):
+        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
+
+    print("\033[92mSiluAndMul Test passed!\033[0m")
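
Usage sketch (illustrative only: the `infinicore.empty` constructor and the "musa"
device string below are assumptions, not part of this patch; the functional API
itself is the one added in python/infinicore/nn/functional/silu_and_mul.py):

    import infinicore
    import infinicore.nn.functional as F

    # x packs [gate, up] along the last dimension: shape [..., 2*d].
    x = infinicore.empty((4, 2048), dtype=infinicore.float16, device="musa")  # assumed helpers

    # Functional form: allocates and returns a new [..., d] tensor.
    y = F.silu_and_mul(x)

    # Destination-passing form: writes into a pre-allocated [..., d] buffer.
    out = infinicore.empty((4, 1024), dtype=infinicore.float16, device="musa")
    F.silu_and_mul(x, out=out)

The reference semantics are exactly those of the PyTorch helpers in the tests
above: split the last dimension in half into gate and up, then compute
SiLU(gate) * up.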