diff --git a/include/infinicore/graph/graph.hpp b/include/infinicore/graph/graph.hpp index c63b3272d..d997e0224 100644 --- a/include/infinicore/graph/graph.hpp +++ b/include/infinicore/graph/graph.hpp @@ -30,17 +30,21 @@ class GraphOperator { class Graph { public: - Graph() = default; - ~Graph() = default; + Graph(); + ~Graph(); void run() const; protected: void add_operator(std::shared_ptr op); - + void instantiate(); std::vector> op_list_; friend class GraphManager; + +private: + struct DeviceGraph; + std::unique_ptr device_graph_; }; } // namespace infinicore::graph diff --git a/include/infinicore/tensor.hpp b/include/infinicore/tensor.hpp index e9f210186..e9a62eb72 100644 --- a/include/infinicore/tensor.hpp +++ b/include/infinicore/tensor.hpp @@ -133,7 +133,18 @@ class TensorImpl : public std::enable_shared_from_this { void debug() const; - Tensor to_blob() const; + /** + * Unsafe API that returns a new tensor with the same raw memory untracked by allocator + * This API is used for loosely tracking a piece of memory while allowing it to be reused, + * typically in a compute graph scenario. + */ + Tensor to_blob_() const; + + /** + * Unsafe API that returns a new tensor with the same memory and let allocator retracks the memory. + * Should only be used on the tensor returned by to_blob_(). + */ + Tensor resume_from_blob_() const; /// /// Data Transfer APIs @@ -299,6 +310,10 @@ class TensorImpl : public std::enable_shared_from_this { protected: TensorMetaData meta_; TensorData data_; + +private: + // Mark to indicate if the tensor is created from to_blob_() + bool to_blob_mark_ = false; }; } // namespace infinicore diff --git a/include/infiniop/ops/swiglu_cuda.h b/include/infiniop/ops/swiglu_cuda.h new file mode 100644 index 000000000..bebce0dc8 --- /dev/null +++ b/include/infiniop/ops/swiglu_cuda.h @@ -0,0 +1,26 @@ +#ifndef __INFINIOP_SWIGLU_CUDA_API_H__ +#define __INFINIOP_SWIGLU_CUDA_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopSwiGLUCudaDescriptor_t; + +__C __export infiniStatus_t infiniopCreateSwiGLUCudaDescriptor(infiniopHandle_t handle, + infiniopSwiGLUCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc); + +__C __export infiniStatus_t infiniopGetSwiGLUCudaWorkspaceSize(infiniopSwiGLUCudaDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopSwiGLUCuda(infiniopSwiGLUCudaDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *c, + void const *a, + void const *b, + void *stream); + +__C __export infiniStatus_t infiniopDestroySwiGLUCudaDescriptor(infiniopSwiGLUCudaDescriptor_t desc); + +#endif diff --git a/include/infinirt.h b/include/infinirt.h index ba16c19b2..40442e75a 100644 --- a/include/infinirt.h +++ b/include/infinirt.h @@ -6,6 +6,9 @@ typedef void *infinirtStream_t; typedef void *infinirtEvent_t; +typedef void *infinirtGraph_t; +typedef void *infinirtGraphNode_t; +typedef void *infinirtGraphExec_t; __C __export infiniStatus_t infinirtInit(); @@ -63,4 +66,24 @@ __C __export infiniStatus_t infinirtMemcpyAsync(void *dst, const void *src, size __C __export infiniStatus_t infinirtMallocAsync(void **p_ptr, size_t size, infinirtStream_t stream); __C __export infiniStatus_t infinirtFreeAsync(void *ptr, infinirtStream_t stream); +// Graph +typedef enum { + INFINIRT_STREAM_CAPTURE_MODE_GLOBAL = 0, + INFINIRT_STREAM_CAPTURE_MODE_THREAD_LOCAL = 1, + INFINIRT_STREAM_CAPTURE_MODE_RELAXED = 2, + +} 
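/* These values are intended to map 1:1 onto the CUDA stream-capture modes
 * (see the mapping in src/infinirt/cuda/infinirt_cuda.cu): GLOBAL is the
 * conservative default, THREAD_LOCAL limits capture checks to the capturing
 * thread, and RELAXED relaxes them. Backends other than CUDA currently
 * reject capture with INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED. */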
infinirtStreamCaptureMode_t; + +__C __export infiniStatus_t infinirtStreamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode); +__C __export infiniStatus_t infinirtStreamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr); +__C __export infiniStatus_t infinirtGraphDestroy(infinirtGraph_t graph); +__C __export infiniStatus_t infinirtGraphInstantiate( + infinirtGraphExec_t *graph_exec_ptr, + infinirtGraph_t graph, + infinirtGraphNode_t *node_ptr, + char *log_buffer, + size_t buffer_size); +__C __export infiniStatus_t infinirtGraphExecDestroy(infinirtGraphExec_t graph_exec); +__C __export infiniStatus_t infinirtGraphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream); + #endif // __INFINIRT_API_H__ diff --git a/src/infinicore/context/allocators/pinnable_block_allocator.cc b/src/infinicore/context/allocators/pinnable_block_allocator.cc index f41800d7c..5574374a8 100644 --- a/src/infinicore/context/allocators/pinnable_block_allocator.cc +++ b/src/infinicore/context/allocators/pinnable_block_allocator.cc @@ -52,9 +52,19 @@ std::byte *PinnableBlockAllocator::allocate(size_t size) { if (size <= cls.block_size) { if (!cls.free_blocks.empty()) { block = cls.free_blocks.back(); - cls.free_blocks.pop_back(); - block->in_use = true; - return reinterpret_cast(block->ptr); + while (block != nullptr && block->in_use) { + cls.free_blocks.pop_back(); + if (cls.free_blocks.empty()) { + block = nullptr; + break; + } + block = cls.free_blocks.back(); + } + if (block != nullptr) { + cls.free_blocks.pop_back(); + block->in_use = true; + return reinterpret_cast(block->ptr); + } } // Allocate a new block for this class block = std::make_shared(); @@ -125,6 +135,16 @@ void PinnableBlockAllocator::deallocate(std::byte *ptr) { } } +size_t PinnableBlockAllocator::mark_in_use_(void *ptr, bool in_use) { + auto it = all_blocks_.find(reinterpret_cast(ptr)); + if (it == all_blocks_.end()) { + throw std::runtime_error("Pointer not allocated by this allocator"); + } + std::lock_guard lock(mutex_); + it->second->in_use = in_use; + return it->second->size; +} + // ------------------- trim ------------------- void PinnableBlockAllocator::trim() { std::lock_guard lock(mutex_); diff --git a/src/infinicore/context/allocators/pinnable_block_allocator.hpp b/src/infinicore/context/allocators/pinnable_block_allocator.hpp index 8911d2a6d..4ab4b4a31 100644 --- a/src/infinicore/context/allocators/pinnable_block_allocator.hpp +++ b/src/infinicore/context/allocators/pinnable_block_allocator.hpp @@ -32,6 +32,10 @@ class PinnableBlockAllocator : public MemoryAllocator { // Switch pinned/graph mode void set_pin_mode(bool pinned) { pinned_mode_ = pinned; } + // internal use only, force set in_use flag for a mem block + // return the size of the block + size_t mark_in_use_(void *ptr, bool in_use); + // trim cached blocks back to GPU (not pinned) void trim(); diff --git a/src/infinicore/context/context_impl.cc b/src/infinicore/context/context_impl.cc index 6ed79af54..67472b067 100644 --- a/src/infinicore/context/context_impl.cc +++ b/src/infinicore/context/context_impl.cc @@ -1,4 +1,5 @@ #include "context_impl.hpp" +#include "internal.hpp" #include "../utils.hpp" @@ -194,6 +195,12 @@ void addGraphOperator(std::shared_ptr op) { std::shared_ptr stopGraphRecording() { return ContextImpl::singleton().getCurrentRuntime()->stopGraphRecording(); } + +std::shared_ptr reinstantiateBlob(std::shared_ptr blob) { + setDevice(blob->device()); + return 
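           // Hand the untracked blob back to the runtime that owns the device
           // allocator: Runtime::reinstantiateBlob re-marks the underlying
           // block as in-use and wraps it in a Memory with a real deleter, so
           // ownership is tracked again after a to_blob_() round trip.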
ContextImpl::singleton().getCurrentRuntime()->reinstantiateBlob(blob); +} + } // namespace context } // namespace infinicore diff --git a/src/infinicore/context/internal.hpp b/src/infinicore/context/internal.hpp new file mode 100644 index 000000000..aeecaff51 --- /dev/null +++ b/src/infinicore/context/internal.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "infinicore/device.hpp" +#include "infinicore/memory.hpp" + +#include "infinicore/graph/graph.hpp" + +namespace infinicore::context { +std::shared_ptr reinstantiateBlob(std::shared_ptr blob); +}; diff --git a/src/infinicore/context/runtime/runtime.cc b/src/infinicore/context/runtime/runtime.cc index a6dd7eb7e..5a6f5b9c3 100644 --- a/src/infinicore/context/runtime/runtime.cc +++ b/src/infinicore/context/runtime/runtime.cc @@ -77,6 +77,15 @@ std::shared_ptr Runtime::allocatePinnedHostMemory(size_t size) { true); } +std::shared_ptr Runtime::reinstantiateBlob(std::shared_ptr blob) { + device_memory_allocator_.get()->mark_in_use_(blob->data(), true); + return std::make_shared( + blob->data(), blob->size(), device_, + [alloc = device_memory_allocator_.get()](std::byte *p) { + alloc->deallocate(p); + }); +} + void Runtime::memcpyH2D(void *dst, const void *src, size_t size, bool async) { if (async) { INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_H2D, stream_)); diff --git a/src/infinicore/context/runtime/runtime.hpp b/src/infinicore/context/runtime/runtime.hpp index 58d8bd424..b5a90a602 100644 --- a/src/infinicore/context/runtime/runtime.hpp +++ b/src/infinicore/context/runtime/runtime.hpp @@ -37,6 +37,7 @@ class Runtime { std::shared_ptr allocateMemory(size_t size); std::shared_ptr allocatePinnedHostMemory(size_t size); + std::shared_ptr reinstantiateBlob(std::shared_ptr blob); void memcpyH2D(void *dst, const void *src, size_t size, bool async = true); void memcpyD2H(void *dst, const void *src, size_t size); diff --git a/src/infinicore/graph/graph.cc b/src/infinicore/graph/graph.cc index 86944af36..8218b1b48 100644 --- a/src/infinicore/graph/graph.cc +++ b/src/infinicore/graph/graph.cc @@ -1,6 +1,8 @@ #include "graph_manager.hpp" #include "../utils.hpp" +#include "infinicore/context/context.hpp" +#include namespace infinicore::graph { @@ -8,7 +10,7 @@ namespace infinicore::graph { * GraphTensor * ========================= */ -GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob()) { +GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob_()) { } /* ========================= @@ -29,9 +31,40 @@ GraphOperator::~GraphOperator() { * Graph * ========================= */ +struct Graph::DeviceGraph { + infinirtGraph_t graph; + infinirtGraphExec_t exec; + infinirtGraphNode_t node; + std::vector log_buffer; + + DeviceGraph() { + log_buffer.resize(4 * 1024); + } + + ~DeviceGraph() { + if (exec) { + infinirtGraphExecDestroy(exec); + } + if (graph) { + infinirtGraphDestroy(graph); + } + } + + void launch() { + INFINICORE_CHECK_ERROR(infinirtGraphLuanch(exec, context::getStream())); + } +}; + +Graph::Graph() { +} + void Graph::run() const { - for (auto &op : op_list_) { - op->run(); + if (device_graph_ != nullptr && device_graph_.get()->exec != nullptr) { + device_graph_.get()->launch(); + } else { + for (auto &op : op_list_) { + op->run(); + } } } @@ -39,6 +72,50 @@ void Graph::add_operator(std::shared_ptr op) { op_list_.push_back(op); } +void Graph::instantiate() { + // Reset device graph + device_graph_ = std::make_unique(); + + // warmup + for (size_t iter = 0; iter < 5; ++iter) { + 
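        // Warm-up replays before capture: a plausible rationale (not stated in
        // this patch) is to let lazily-initialized library handles and the
        // pinnable block allocator's cached blocks reach a steady state, so no
        // new allocations or handle creation happen while the stream is captured.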
this->run(); + } + infinicore::context::syncStream(); + + if (infinirtStreamBeginCapture( + context::getStream(), + INFINIRT_STREAM_CAPTURE_MODE_GLOBAL) + != INFINI_STATUS_SUCCESS) { + return; + } + + // Run and record + this->run(); + + if (infinirtStreamEndCapture( + context::getStream(), + &device_graph_.get()->graph) + != INFINI_STATUS_SUCCESS) { + return; + } + + if (infinirtGraphInstantiate( + &device_graph_.get()->exec, + device_graph_.get()->graph, + &device_graph_.get()->node, + device_graph_.get()->log_buffer.data(), + device_graph_.get()->log_buffer.size()) + != INFINI_STATUS_SUCCESS) { + static bool warned_once = false; + if (!warned_once) { + warned_once = true; + spdlog::warn("Fail to instantiate device graph: {}", std::string(device_graph_.get()->log_buffer.data())); + } + } +} + +Graph::~Graph() = default; + /* ========================= * GraphManager * ========================= */ @@ -48,19 +125,26 @@ bool GraphManager::is_recording() const { } void GraphManager::start_recording() { + if (is_recording()) { + spdlog::warn("Graph is already recording. Previous recording will be dropped."); + } recording_ = true; graph_ = std::make_shared(); } void GraphManager::add_operator(std::shared_ptr op) { - INFINICORE_ASSERT(recording_); + INFINICORE_ASSERT(is_recording()); graph_->add_operator(op); } std::shared_ptr GraphManager::stop_recording() { - + if (!is_recording()) { + spdlog::warn("Graph is not recording. Please start recording first."); + return nullptr; + } recording_ = false; + graph_->instantiate(); return std::exchange(graph_, nullptr); } diff --git a/src/infinicore/tensor/copy.cc b/src/infinicore/tensor/copy.cc index 995187a12..1297d9f8c 100644 --- a/src/infinicore/tensor/copy.cc +++ b/src/infinicore/tensor/copy.cc @@ -38,7 +38,7 @@ void TensorImpl::copy_from(Tensor src) { } else { auto local_src = Tensor::empty(this->shape(), this->dtype(), this->device()); context::setDevice(src->device()); - context::memcpyD2H(local_src->data(), src->data(), this->data_.memory->size()); + context::memcpyD2H(local_src->data(), src->data(), copy_size); op::rearrange_(Tensor(const_cast(this)->shared_from_this()), local_src); } } else if (src->device().getType() == Device::Type::CPU) { diff --git a/src/infinicore/tensor/tensor.cc b/src/infinicore/tensor/tensor.cc index 2acc6dec8..34a9af601 100644 --- a/src/infinicore/tensor/tensor.cc +++ b/src/infinicore/tensor/tensor.cc @@ -1,4 +1,5 @@ #include "infinicore/tensor.hpp" +#include "../context/internal.hpp" #include "../utils.hpp" #include "infinicore/context/context.hpp" #include "infinicore/dtype.hpp" @@ -275,10 +276,22 @@ std::shared_ptr TensorImpl::strided_from_blob( return t; } -Tensor TensorImpl::to_blob() const { +Tensor TensorImpl::to_blob_() const { auto t = std::shared_ptr(new TensorImpl(shape(), strides(), dtype())); t->data_.offset = this->data_.offset; t->data_.memory = std::make_shared(this->data_.memory->data(), this->data_.memory->size(), this->data_.memory->device(), nullptr); + t->to_blob_mark_ = true; + return Tensor{t}; +} + +Tensor TensorImpl::resume_from_blob_() const { + auto t = std::shared_ptr(new TensorImpl(shape(), strides(), dtype())); + t->data_.offset = this->data_.offset; + if (to_blob_mark_) { + t->data_.memory = context::reinstantiateBlob(this->data_.memory); + } else { + t->data_.memory = this->data_.memory; + } return Tensor{t}; } diff --git a/src/infiniop/devices/nvidia/nvidia_common.cu b/src/infiniop/devices/nvidia/nvidia_common.cu index f712b6053..7c2369f1c 100644 --- 
a/src/infiniop/devices/nvidia/nvidia_common.cu +++ b/src/infiniop/devices/nvidia/nvidia_common.cu @@ -23,6 +23,10 @@ Handle::Internal::Internal(int device_id) { _grid_size[0] = prop.maxGridSize[0]; _grid_size[1] = prop.maxGridSize[1]; _grid_size[2] = prop.maxGridSize[2]; + this->useCublas(nullptr, [](cublasHandle_t handle) { return INFINI_STATUS_SUCCESS; }); +#ifdef ENABLE_CUDNN_API + this->useCudnn(nullptr, [](cudnnHandle_t handle) { return INFINI_STATUS_SUCCESS; }); +#endif } infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn &f) const { diff --git a/src/infiniop/ops/swiglu_cuda/cuda/kernel.cuh b/src/infiniop/ops/swiglu_cuda/cuda/kernel.cuh new file mode 100644 index 000000000..3dc627e10 --- /dev/null +++ b/src/infiniop/ops/swiglu_cuda/cuda/kernel.cuh @@ -0,0 +1,78 @@ +#ifndef __SWIGLU_CUDA_KERNEL_CUH__ +#define __SWIGLU_CUDA_KERNEL_CUH__ +template +__device__ __forceinline__ T sigmoid(const T &x) { + if constexpr (std::is_same_v) { + return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x)))); + } else if constexpr (std::is_same_v) { + return hrcp(__hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x)))))); + } else if constexpr (std::is_same_v) { + float x0 = __bfloat162float(__low2bfloat16(x)); + float x1 = __bfloat162float(__high2bfloat16(x)); + float sig0 = __frcp_rn(__fadd_rn(1.0f, __expf(-x0))); + float sig1 = __frcp_rn(__fadd_rn(1.0f, __expf(-x1))); + return __floats2bfloat162_rn(sig0, sig1); + } else if constexpr (std::is_same_v) { + float xf = __bfloat162float(x); + return __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-xf)))); + } else if constexpr (std::is_same_v) { + return __frcp_rn(__fadd_rn(1, __expf(-x))); + } else { + return 1 / (1 + std::exp(-x)); + } +} + +template +__device__ void SwiGLUCudaKernel( + T *c, + const T *a, + const T *b, + int length, + const size_t *shape, + const ptrdiff_t *c_strides, + const ptrdiff_t *a_strides, + const ptrdiff_t *b_strides, + int ndim) { + int ind_c = 0; + int ind_a = 0; + int ind_b = 0; + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < length) { + for (int j = ndim - 1; j >= 0; j--) { + ind_c += (tid % (int)shape[j]) * (int)c_strides[j]; + ind_a += (tid % (int)shape[j]) * (int)a_strides[j]; + ind_b += (tid % (int)shape[j]) * (int)b_strides[j]; + tid = tid / (int)shape[j]; + } + T gate = b[ind_b]; + T up = a[ind_a]; + if constexpr (std::is_same_v) { + c[ind_c] = __hmul2(__hmul2(gate, sigmoid(gate)), up); + } else if constexpr (std::is_same_v) { + c[ind_c] = __hmul(__hmul(gate, sigmoid(gate)), up); + } else if constexpr (std::is_same_v) { + cuda_bfloat162 sig = sigmoid(gate); + float gate0 = __bfloat162float(__low2bfloat16(gate)); + float gate1 = __bfloat162float(__high2bfloat16(gate)); + float sig0 = __bfloat162float(__low2bfloat16(sig)); + float sig1 = __bfloat162float(__high2bfloat16(sig)); + float up0 = __bfloat162float(__low2bfloat16(up)); + float up1 = __bfloat162float(__high2bfloat16(up)); + float res0 = __fmul_rn(__fmul_rn(gate0, sig0), up0); + float res1 = __fmul_rn(__fmul_rn(gate1, sig1), up1); + c[ind_c] = __floats2bfloat162_rn(res0, res1); + } else if constexpr (std::is_same_v) { + cuda_bfloat16 sig = sigmoid(gate); + float gatef = __bfloat162float(gate); + float sigf = __bfloat162float(sig); + float upf = __bfloat162float(up); + c[ind_c] = __float2bfloat16_rn(__fmul_rn(__fmul_rn(gatef, sigf), upf)); + } else if constexpr (std::is_same_v) { + c[ind_c] = __fmul_rn(__fmul_rn(gate, sigmoid(gate)), up); + } else { + c[ind_c] = gate * sigmoid(gate) * up; + } + } +} + +#endif // 
__SWIGLU_CUDA_KERNEL_CUH__ diff --git a/src/infiniop/ops/swiglu_cuda/info.h b/src/infiniop/ops/swiglu_cuda/info.h new file mode 100644 index 000000000..6d06768df --- /dev/null +++ b/src/infiniop/ops/swiglu_cuda/info.h @@ -0,0 +1,52 @@ +#ifndef __SWIGLU_CUDA_INFO_H__ +#define __SWIGLU_CUDA_INFO_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" + +namespace op::swiglu_cuda { + +class SwiGLUCudaInfo { + SwiGLUCudaInfo() = default; + +public: + infiniDtype_t dtype; + size_t length; + size_t ndim; + std::vector shape; + std::vector c_strides; + std::vector a_strides; + std::vector b_strides; + + static utils::Result createSwiGLUCudaInfo(infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) { + auto dtype = c_desc->dtype(); + if (dtype != a_desc->dtype() || dtype != b_desc->dtype()) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + auto shape = c_desc->shape(); + CHECK_SAME_SHAPE(shape, a_desc->shape(), b_desc->shape()); + + auto ndim = c_desc->ndim(); + + size_t length = 1; + for (int i = 0; i < (int)ndim; i++) { + length *= shape[i]; + } + + return utils::Result(SwiGLUCudaInfo{ + dtype, + length, + ndim, + shape, + c_desc->strides(), + a_desc->strides(), + b_desc->strides()}); + } +}; + +} // namespace op::swiglu_cuda + +#endif // __SWIGLU_CUDA_INFO_H__ diff --git a/src/infiniop/ops/swiglu_cuda/metax/swiglu_cuda_metax.cuh b/src/infiniop/ops/swiglu_cuda/metax/swiglu_cuda_metax.cuh new file mode 100644 index 000000000..5f0c404f8 --- /dev/null +++ b/src/infiniop/ops/swiglu_cuda/metax/swiglu_cuda_metax.cuh @@ -0,0 +1,8 @@ +#ifndef __SWIGLU_CUDA_METAX_H__ +#define __SWIGLU_CUDA_METAX_H__ + +#include "../swiglu_cuda.h" + +DESCRIPTOR(metax) + +#endif diff --git a/src/infiniop/ops/swiglu_cuda/metax/swiglu_cuda_metax.maca b/src/infiniop/ops/swiglu_cuda/metax/swiglu_cuda_metax.maca new file mode 100644 index 000000000..12d1a922b --- /dev/null +++ b/src/infiniop/ops/swiglu_cuda/metax/swiglu_cuda_metax.maca @@ -0,0 +1,113 @@ +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_kernel_common.h" + +#include "../cuda/kernel.cuh" +#include "swiglu_cuda_metax.cuh" + +template +INFINIOP_METAX_KERNEL SwiGLUCuda( + T *c, + const T *a, + const T *b, + int length, + const size_t *shape, + const ptrdiff_t *c_strides, + const ptrdiff_t *a_strides, + const ptrdiff_t *b_strides, + int ndim) { + SwiGLUCudaKernel(c, a, b, length, shape, c_strides, a_strides, b_strides, ndim); +} + +namespace op::swiglu_cuda::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { + + auto info = SwiGLUCudaInfo::createSwiGLUCudaInfo(c_desc, a_desc, b_desc); + CHECK_RESULT(info); + size_t WorkSpaceSize = c_desc->ndim() * (sizeof(ptrdiff_t) * 3 + sizeof(size_t)); + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), WorkSpaceSize, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculate_swiglu_cuda( + const SwiGLUCudaInfo &info, + T *c, + const T *a, + const T *b, + cudaStream_t stream, + void *workspace) { + int ndim = (int)info.ndim; + char *workspace_ptr = 
reinterpret_cast(workspace); + ptrdiff_t *c_strides_cuda = reinterpret_cast(workspace_ptr); + ptrdiff_t *a_strides_cuda = c_strides_cuda + ndim; + ptrdiff_t *b_strides_cuda = a_strides_cuda + ndim; + + size_t ptrdiff_array_size = 3 * ndim * sizeof(ptrdiff_t); + size_t *shape_cuda = reinterpret_cast(workspace_ptr + ptrdiff_array_size); + + CHECK_CUDA(cudaMemcpyAsync(c_strides_cuda, info.c_strides.data(), sizeof(ptrdiff_t) * ndim, cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemcpyAsync(a_strides_cuda, info.a_strides.data(), sizeof(ptrdiff_t) * ndim, cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemcpyAsync(b_strides_cuda, info.b_strides.data(), sizeof(ptrdiff_t) * ndim, cudaMemcpyHostToDevice, stream)); + + CHECK_CUDA(cudaMemcpyAsync(shape_cuda, info.shape.data(), sizeof(size_t) * ndim, cudaMemcpyHostToDevice, stream)); + int length = (int)info.length; + int num_blocks = (length + BLOCK_SIZE - 1) / BLOCK_SIZE; + SwiGLUCuda + <<>>(c, a, b, length, shape_cuda, c_strides_cuda, a_strides_cuda, b_strides_cuda, ndim); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *c, + const void *a, + const void *b, + void *stream_) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + cudaStream_t stream = (cudaStream_t)stream_; + +#define CALCULATE_SWIGLU_CUDA(BLOCK_SIZE, TDATA) \ + calculate_swiglu_cuda(_info, (TDATA *)c, (const TDATA *)a, (const TDATA *)b, stream, workspace) +#define CALCULATE_SWIGLU_CUDA_WITH_BLOCK_SIZE(BLOCK_SIZE) \ + { \ + if (_info.dtype == INFINI_DTYPE_F16) \ + return CALCULATE_SWIGLU_CUDA(BLOCK_SIZE, half); \ + else if (_info.dtype == INFINI_DTYPE_F32) \ + return CALCULATE_SWIGLU_CUDA(BLOCK_SIZE, float); \ + else if (_info.dtype == INFINI_DTYPE_BF16) \ + return CALCULATE_SWIGLU_CUDA(BLOCK_SIZE, __hpcc_bfloat16); \ + else \ + return INFINI_STATUS_BAD_TENSOR_DTYPE; \ + } + + if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) { + CALCULATE_SWIGLU_CUDA_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_1024) + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::swiglu_cuda::nvidia diff --git a/src/infiniop/ops/swiglu_cuda/nvidia/swiglu_cuda_nvidia.cu b/src/infiniop/ops/swiglu_cuda/nvidia/swiglu_cuda_nvidia.cu new file mode 100644 index 000000000..e222569b9 --- /dev/null +++ b/src/infiniop/ops/swiglu_cuda/nvidia/swiglu_cuda_nvidia.cu @@ -0,0 +1,116 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "../cuda/kernel.cuh" +#include "swiglu_cuda_nvidia.cuh" + +template +INFINIOP_CUDA_KERNEL SwiGLUCuda( + T *c, + const T *a, + const T *b, + int length, + const size_t *shape, + const ptrdiff_t *c_strides, + const ptrdiff_t *a_strides, + const ptrdiff_t *b_strides, + int ndim) { + SwiGLUCudaKernel(c, a, b, length, shape, c_strides, a_strides, b_strides, ndim); +} + +namespace op::swiglu_cuda::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { + + auto info = SwiGLUCudaInfo::createSwiGLUCudaInfo(c_desc, a_desc, b_desc); + CHECK_RESULT(info); + size_t WorkSpaceSize = c_desc->ndim() * (sizeof(ptrdiff_t) * 3 + sizeof(size_t)); + 
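    // Workspace layout, sized above as ndim * (3 * sizeof(ptrdiff_t) + sizeof(size_t)):
    // three ndim-length stride arrays (c, a, b) followed by one ndim-length shape
    // array, staged on the device with cudaMemcpyAsync before each kernel launch
    // (see calculate_swiglu_cuda below).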
*desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), WorkSpaceSize, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculate_swiglu_cuda( + const SwiGLUCudaInfo &info, + T *c, + const T *a, + const T *b, + cudaStream_t stream, + void *workspace) { + int ndim = (int)info.ndim; + char *workspace_ptr = reinterpret_cast(workspace); + ptrdiff_t *c_strides_cuda = reinterpret_cast(workspace_ptr); + ptrdiff_t *a_strides_cuda = c_strides_cuda + ndim; + ptrdiff_t *b_strides_cuda = a_strides_cuda + ndim; + + size_t ptrdiff_array_size = 3 * ndim * sizeof(ptrdiff_t); + size_t *shape_cuda = reinterpret_cast(workspace_ptr + ptrdiff_array_size); + + CHECK_CUDA(cudaMemcpyAsync(c_strides_cuda, info.c_strides.data(), sizeof(ptrdiff_t) * ndim, cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemcpyAsync(a_strides_cuda, info.a_strides.data(), sizeof(ptrdiff_t) * ndim, cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemcpyAsync(b_strides_cuda, info.b_strides.data(), sizeof(ptrdiff_t) * ndim, cudaMemcpyHostToDevice, stream)); + + CHECK_CUDA(cudaMemcpyAsync(shape_cuda, info.shape.data(), sizeof(size_t) * ndim, cudaMemcpyHostToDevice, stream)); + int length = (int)info.length; + int num_blocks = (length + BLOCK_SIZE - 1) / BLOCK_SIZE; + SwiGLUCuda + <<>>(c, a, b, length, shape_cuda, c_strides_cuda, a_strides_cuda, b_strides_cuda, ndim); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *c, + const void *a, + const void *b, + void *stream_) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + cudaStream_t stream = (cudaStream_t)stream_; + +#define CALCULATE_SWIGLU_CUDA(BLOCK_SIZE, TDATA) \ + calculate_swiglu_cuda(_info, (TDATA *)c, (const TDATA *)a, (const TDATA *)b, stream, workspace) +#define CALCULATE_SWIGLU_CUDA_WITH_BLOCK_SIZE(BLOCK_SIZE) \ + { \ + if (_info.dtype == INFINI_DTYPE_F16) \ + return CALCULATE_SWIGLU_CUDA(BLOCK_SIZE, half); \ + else if (_info.dtype == INFINI_DTYPE_F32) \ + return CALCULATE_SWIGLU_CUDA(BLOCK_SIZE, float); \ + else if (_info.dtype == INFINI_DTYPE_BF16) \ + return CALCULATE_SWIGLU_CUDA(BLOCK_SIZE, __nv_bfloat16); \ + else \ + return INFINI_STATUS_BAD_TENSOR_DTYPE; \ + } + + if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { + CALCULATE_SWIGLU_CUDA_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024) + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { + CALCULATE_SWIGLU_CUDA_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512) + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { + CALCULATE_SWIGLU_CUDA_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096) + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::swiglu_cuda::nvidia diff --git a/src/infiniop/ops/swiglu_cuda/nvidia/swiglu_cuda_nvidia.cuh b/src/infiniop/ops/swiglu_cuda/nvidia/swiglu_cuda_nvidia.cuh new file mode 100644 index 000000000..32ca31c8a --- /dev/null +++ b/src/infiniop/ops/swiglu_cuda/nvidia/swiglu_cuda_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __SWIGLU_CUDA_NVIDIA_H__ +#define __SWIGLU_CUDA_NVIDIA_H__ + +#include "../swiglu_cuda.h" + +DESCRIPTOR(nvidia) + +#endif diff --git a/src/infiniop/ops/swiglu_cuda/operator.cc b/src/infiniop/ops/swiglu_cuda/operator.cc new file mode 100644 index 000000000..e8c24e799 --- /dev/null +++ b/src/infiniop/ops/swiglu_cuda/operator.cc @@ -0,0 +1,113 @@ +#include 
"../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/swiglu_cuda.h" + +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) +#include "nvidia/swiglu_cuda_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "metax/swiglu_cuda_metax.cuh" +#endif + +__C infiniStatus_t infiniopCreateSwiGLUCudaDescriptor( + infiniopHandle_t handle, + infiniopSwiGLUCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::swiglu_cuda::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + c_desc, \ + a_desc, \ + b_desc) + + switch (handle->device) { + +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetSwiGLUCudaWorkspaceSize(infiniopSwiGLUCudaDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET +} + +__C infiniStatus_t infiniopSwiGLUCuda( + infiniopSwiGLUCudaDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *c, + const void *a, + const void *b, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, c, a, b, stream) + + switch (desc->device_type) { + +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t +infiniopDestroySwiGLUCudaDescriptor(infiniopSwiGLUCudaDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} diff --git a/src/infiniop/ops/swiglu_cuda/swiglu_cuda.h b/src/infiniop/ops/swiglu_cuda/swiglu_cuda.h new file mode 100644 index 000000000..ac975278d --- /dev/null +++ b/src/infiniop/ops/swiglu_cuda/swiglu_cuda.h @@ -0,0 +1,48 @@ +#ifndef SWIGLU_CUDA_H +#define SWIGLU_CUDA_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::swiglu_cuda::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + SwiGLUCudaInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + SwiGLUCudaInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor 
**desc_ptr, \ + infiniopTensorDescriptor_t c_desc, \ + infiniopTensorDescriptor_t a_desc, \ + infiniopTensorDescriptor_t b_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *c, \ + const void *a, \ + const void *b, \ + void *stream) const; \ + }; \ + } + +#endif // SWIGLU_CUDA_H diff --git a/src/infinirt/ascend/infinirt_ascend.cc b/src/infinirt/ascend/infinirt_ascend.cc index 4731f086a..b12fbc89b 100644 --- a/src/infinirt/ascend/infinirt_ascend.cc +++ b/src/infinirt/ascend/infinirt_ascend.cc @@ -150,5 +150,35 @@ infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) { infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) { return freeDevice(ptr); } + +infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphDestroy(infinirtGraph_t graph) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphInstantiate( + infinirtGraphExec_t *graph_exec_ptr, + infinirtGraph_t graph, + infinirtGraphNode_t *node_ptr, + char *log_buffer, + size_t buffer_size) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + } // namespace infinirt::ascend #undef CHECK_ACLRT diff --git a/src/infinirt/bang/infinirt_bang.cc b/src/infinirt/bang/infinirt_bang.cc index bccbb7c19..5384add19 100644 --- a/src/infinirt/bang/infinirt_bang.cc +++ b/src/infinirt/bang/infinirt_bang.cc @@ -142,4 +142,34 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) { CHECK_BANGRT(cnrtFree(ptr)); return INFINI_STATUS_SUCCESS; } + +infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphDestroy(infinirtGraph_t graph) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphInstantiate( + infinirtGraphExec_t *graph_exec_ptr, + infinirtGraph_t graph, + infinirtGraphNode_t *node_ptr, + char *log_buffer, + size_t buffer_size) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + } // namespace infinirt::bang diff --git a/src/infinirt/cpu/infinirt_cpu.cc b/src/infinirt/cpu/infinirt_cpu.cc index c8709b1d4..06468ec25 100644 --- a/src/infinirt/cpu/infinirt_cpu.cc +++ b/src/infinirt/cpu/infinirt_cpu.cc @@ -116,4 +116,33 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) { return freeDevice(ptr); } +infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + 
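// Graph capture is a no-op on this backend: every graph entry point reports
// INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED. Graph::instantiate() treats any
// non-success status as "no device graph", so Graph::run() falls back to
// replaying op_list_ operator by operator.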
+infiniStatus_t graphDestroy(infinirtGraph_t graph) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphInstantiate( + infinirtGraphExec_t *graph_exec_ptr, + infinirtGraph_t graph, + infinirtGraphNode_t *node_ptr, + char *log_buffer, + size_t buffer_size) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + } // namespace infinirt::cpu diff --git a/src/infinirt/cuda/infinirt_cuda.cu b/src/infinirt/cuda/infinirt_cuda.cu index dadd5734c..697e47646 100644 --- a/src/infinirt/cuda/infinirt_cuda.cu +++ b/src/infinirt/cuda/infinirt_cuda.cu @@ -176,4 +176,53 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) { RUN_CUDART(cudaFreeAsync(ptr, (cudaStream_t)stream)); return INFINI_STATUS_SUCCESS; } + +infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) { + cudaStreamCaptureMode graph_mode; + if (mode == INFINIRT_STREAM_CAPTURE_MODE_GLOBAL) { + graph_mode = cudaStreamCaptureModeGlobal; + } else if (mode == INFINIRT_STREAM_CAPTURE_MODE_THREAD_LOCAL) { + graph_mode = cudaStreamCaptureModeThreadLocal; + } else if (mode == INFINIRT_STREAM_CAPTURE_MODE_RELAXED) { + graph_mode = cudaStreamCaptureModeRelaxed; + } else { + return INFINI_STATUS_BAD_PARAM; + } + + CHECK_CUDART(cudaStreamBeginCapture((cudaStream_t)stream, graph_mode)); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) { + cudaGraph_t graph; + CHECK_CUDART(cudaStreamEndCapture((cudaStream_t)stream, &graph)); + *graph_ptr = graph; + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t graphDestroy(infinirtGraph_t graph) { + RUN_CUDART(cudaGraphDestroy((cudaGraph_t)graph)); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t graphInstantiate( + infinirtGraphExec_t *graph_exec_ptr, + infinirtGraph_t graph, + infinirtGraphNode_t *node_ptr, + char *log_buffer, + size_t buffer_size) { + CHECK_CUDART(cudaGraphInstantiate((cudaGraphExec_t *)graph_exec_ptr, (cudaGraph_t)graph, (cudaGraphNode_t *)node_ptr, log_buffer, buffer_size)); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) { + RUN_CUDART(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec)); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) { + CHECK_CUDART(cudaGraphLaunch((cudaGraphExec_t)graph_exec, (cudaStream_t)stream)); + return INFINI_STATUS_SUCCESS; +} } diff --git a/src/infinirt/infinirt.cc b/src/infinirt/infinirt.cc index 82bb5ab28..e16f1c0f4 100644 --- a/src/infinirt/infinirt.cc +++ b/src/infinirt/infinirt.cc @@ -192,3 +192,32 @@ __C infiniStatus_t infinirtMallocAsync(void **p_ptr, size_t size, infinirtStream __C infiniStatus_t infinirtFreeAsync(void *ptr, infinirtStream_t stream) { INFINIRT_CALL_DEVICE_API(freeAsync, (ptr, stream)); } + +__C infiniStatus_t infinirtStreamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) { + INFINIRT_CALL_DEVICE_API(streamBeginCapture, (stream, mode)); +} + +__C infiniStatus_t infinirtStreamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) { + INFINIRT_CALL_DEVICE_API(streamEndCapture, (stream, graph_ptr)); +} + +__C infiniStatus_t infinirtGraphDestroy(infinirtGraph_t graph) { + 
INFINIRT_CALL_DEVICE_API(graphDestroy, (graph)); +} + +__C infiniStatus_t infinirtGraphInstantiate( + infinirtGraphExec_t *graph_exec_ptr, + infinirtGraph_t graph, + infinirtGraphNode_t *node_ptr, + char *log_buffer, + size_t buffer_size) { + INFINIRT_CALL_DEVICE_API(graphInstantiate, (graph_exec_ptr, graph, node_ptr, log_buffer, buffer_size)); +} + +__C infiniStatus_t infinirtGraphExecDestroy(infinirtGraphExec_t graph_exec) { + INFINIRT_CALL_DEVICE_API(graphExecDestroy, (graph_exec)); +} + +__C infiniStatus_t infinirtGraphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) { + INFINIRT_CALL_DEVICE_API(graphLuanch, (graph_exec, stream)); +} diff --git a/src/infinirt/infinirt_impl.h b/src/infinirt/infinirt_impl.h index 8ae4347c5..c30d078d5 100644 --- a/src/infinirt/infinirt_impl.h +++ b/src/infinirt/infinirt_impl.h @@ -30,7 +30,19 @@ INLINE infiniStatus_t memcpyAsync(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind, infinirtStream_t stream) IMPL; \ \ INLINE infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) IMPL; \ - INLINE infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) IMPL; + INLINE infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) IMPL; \ + \ + INLINE infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) IMPL; \ + INLINE infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) IMPL; \ + INLINE infiniStatus_t graphDestroy(infinirtGraph_t graph) IMPL; \ + INLINE infiniStatus_t graphInstantiate( \ + infinirtGraphExec_t *graph_exec_ptr, \ + infinirtGraph_t graph, \ + infinirtGraphNode_t *node_ptr, \ + char *log_buffer, \ + size_t buffer_size) IMPL; \ + INLINE infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) IMPL; \ + INLINE infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) IMPL; #define INFINIRT_DEVICE_API_IMPL INFINIRT_DEVICE_API(, , ) #define INFINIRT_DEVICE_API_NOOP INFINIRT_DEVICE_API( \ diff --git a/src/infinirt/kunlun/infinirt_kunlun.cc b/src/infinirt/kunlun/infinirt_kunlun.cc index f2fe43680..dc1b34a27 100644 --- a/src/infinirt/kunlun/infinirt_kunlun.cc +++ b/src/infinirt/kunlun/infinirt_kunlun.cc @@ -153,4 +153,33 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) { return INFINI_STATUS_SUCCESS; } +infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphDestroy(infinirtGraph_t graph) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphInstantiate( + infinirtGraphExec_t *graph_exec_ptr, + infinirtGraph_t graph, + infinirtGraphNode_t *node_ptr, + char *log_buffer, + size_t buffer_size) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + } // namespace infinirt::kunlun diff --git a/src/infinirt/metax/infinirt_metax.cc b/src/infinirt/metax/infinirt_metax.cc index bd7131e9d..aca187366 100644 --- a/src/infinirt/metax/infinirt_metax.cc +++ b/src/infinirt/metax/infinirt_metax.cc @@ -152,4 +152,34 @@ infiniStatus_t freeAsync(void *ptr, 
infinirtStream_t stream) { CHECK_MACART(hcFreeAsync(ptr, (hcStream_t)stream)); return INFINI_STATUS_SUCCESS; } + +infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphDestroy(infinirtGraph_t graph) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphInstantiate( + infinirtGraphExec_t *graph_exec_ptr, + infinirtGraph_t graph, + infinirtGraphNode_t *node_ptr, + char *log_buffer, + size_t buffer_size) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + } // namespace infinirt::metax diff --git a/src/infinirt/moore/infinirt_moore.cc b/src/infinirt/moore/infinirt_moore.cc index b1861c0f2..a7ad61c6b 100644 --- a/src/infinirt/moore/infinirt_moore.cc +++ b/src/infinirt/moore/infinirt_moore.cc @@ -138,4 +138,34 @@ infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) { infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) { return freeDevice(ptr); } + +infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphDestroy(infinirtGraph_t graph) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphInstantiate( + infinirtGraphExec_t *graph_exec_ptr, + infinirtGraph_t graph, + infinirtGraphNode_t *node_ptr, + char *log_buffer, + size_t buffer_size) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) { + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + } // namespace infinirt::musa diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 618be2b05..1596829d9 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -560,6 +560,39 @@ def swiglu_(lib): ] +@OpRegister.operator +def swiglu_cuda_(lib): + lib.infiniopCreateSwiGLUCudaDescriptor.restype = c_int32 + lib.infiniopCreateSwiGLUCudaDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetSwiGLUCudaWorkspaceSize.restype = c_int32 + lib.infiniopGetSwiGLUCudaWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopSwiGLUCuda.restype = c_int32 + lib.infiniopSwiGLUCuda.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroySwiGLUCudaDescriptor.restype = c_int32 + lib.infiniopDestroySwiGLUCudaDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + @OpRegister.operator def conv_(lib): lib.infiniopCreateConvDescriptor.restype = 
c_int32 diff --git a/test/infiniop/swiglu_cuda.py b/test/infiniop/swiglu_cuda.py new file mode 100644 index 000000000..d62d538cb --- /dev/null +++ b/test/infiniop/swiglu_cuda.py @@ -0,0 +1,183 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) +from enum import Enum, auto + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not meant to be imported from other modules +_TEST_CASES_ = [ + # shape, a_stride, b_stride, c_stride + ((13, 4), None, None, None), + ((13, 4), (10, 1), (10, 1), (10, 1)), + ((13, 4), (0, 1), None, None), + ((13, 4, 4), None, None, None), + ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)), + ((13, 4, 4), (4, 0, 1), (0, 4, 1), None), + ((16, 5632), None, None, None), + ((16, 5632), (13312, 1), (13312, 1), (13312, 1)), + ((4, 4, 5632), None, None, None), + ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), +] + + +class Inplace(Enum): + OUT_OF_PLACE = auto() + INPLACE_A = auto() + INPLACE_B = auto() + + +# Inplace options applied for each test case in _TEST_CASES_ +_INPLACE = [ + Inplace.OUT_OF_PLACE, + Inplace.INPLACE_A, + Inplace.INPLACE_B, +] + +# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_ +_TEST_CASES = [ + test_case + (inplace_item,) + for test_case in _TEST_CASES_ + for inplace_item in _INPLACE +] + +# Data types used for testing +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.BF16: {"atol": 5e-3, "rtol": 1e-2}, + InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def swiglu_cuda(a, b): + return a * b / (1 + torch.exp(-b.float()).to(b.dtype)) + + +def test( + handle, + device, + shape, + a_stride=None, + b_stride=None, + c_stride=None, + inplace=Inplace.OUT_OF_PLACE, + dtype=InfiniDtype.F16, + sync=None, +): + a = TestTensor(shape, a_stride, dtype, device) + b = TestTensor(shape, b_stride, dtype, device) + if inplace == Inplace.INPLACE_A: + if c_stride is not None and c_stride != a_stride: + return + c = a + elif inplace == Inplace.INPLACE_B: + if c_stride is not None and c_stride != b_stride: + return + c = b + else: + c = TestTensor(shape, c_stride, dtype, device) + + if c.is_broadcast(): + return + + print( + f"Testing SwiGLUCuda on {InfiniDeviceNames[device]} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} " + f"dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}" + ) + + ans = swiglu_cuda(a.torch_tensor(), b.torch_tensor()) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateSwiGLUCudaDescriptor( + handle, + ctypes.byref(descriptor), + c.descriptor, + a.descriptor, + b.descriptor, + ) + ) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + for tensor in [a, b, c]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + 
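        # Query the required workspace first: the SwiGLU CUDA path stages the
        # shape and stride arrays in this buffer on every launch.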
LIBINFINIOP.infiniopGetSwiGLUCudaWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, c.device) + + def lib_swiglu_cuda(): + check_error( + LIBINFINIOP.infiniopSwiGLUCuda( + descriptor, + workspace.data(), + workspace_size.value, + c.data(), + a.data(), + b.data(), + None, + ) + ) + + lib_swiglu_cuda() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(c.actual_tensor(), ans, atol=atol, rtol=rtol) + + assert torch.allclose(c.actual_tensor(), ans, atol=atol, rtol=rtol) + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: swiglu_cuda(a.torch_tensor(), b.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_swiglu_cuda(), device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + check_error(LIBINFINIOP.infiniopDestroySwiGLUCudaDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index a86090776..75086b8a1 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -49,6 +49,7 @@ target("infiniop-nvidia") add_cuflags("--extended-lambda") add_culdflags("-Xcompiler=-fPIC") add_cxxflags("-fPIC") + add_cflags("-fPIC") add_cuflags("--expt-relaxed-constexpr") if CUDNN_ROOT ~= nil then add_linkdirs(CUDNN_ROOT .. "/lib")
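
Usage note: a minimal sketch (illustrative, not part of the patch; the helper
and callback names are made up) of the capture/replay sequence exposed by the
new infinirt graph API, assuming a CUDA-capable backend. Non-CUDA backends
currently return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED, and callers are
expected to fall back to eager execution, as Graph::run() does.

    #include <infinirt.h>
    #include <vector>

    // Capture whatever work `enqueue_work(stream)` puts on `stream`, then
    // instantiate and replay it. Error handling is elided for brevity.
    infiniStatus_t capture_and_replay(infinirtStream_t stream,
                                      void (*enqueue_work)(infinirtStream_t)) {
        infinirtGraph_t graph = nullptr;
        infinirtGraphExec_t exec = nullptr;
        infinirtGraphNode_t node = nullptr;
        std::vector<char> log(4 * 1024);

        infinirtStreamBeginCapture(stream, INFINIRT_STREAM_CAPTURE_MODE_GLOBAL);
        enqueue_work(stream);                       // recorded, not executed
        infinirtStreamEndCapture(stream, &graph);

        infinirtGraphInstantiate(&exec, graph, &node, log.data(), log.size());
        infinirtGraphLuanch(exec, stream);          // replay the captured work

        infinirtGraphExecDestroy(exec);
        infinirtGraphDestroy(graph);
        return INFINI_STATUS_SUCCESS;
    }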