Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor providers into separate libraries #1190

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion src/beam_search_scorer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ BeamSearchScorer::BeamSearchScorer(const GeneratorParams& parameters)

// Space to store intermediate sequence
size_t const per_beam = (max_length_ * (max_length_ + 1)) / 2;
hypothesis_buffer_ = device.Allocate<int32_t>(batch_beam_size * per_beam, true);
hypothesis_buffer_ = device.Allocate<int32_t>(batch_beam_size * per_beam);

memset(next_beam_scores_.Span().data(), 0, next_beam_scores_.Span().size_bytes());

Expand Down
64 changes: 56 additions & 8 deletions src/cpu/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@

#include "../generators.h"
#include "../search.h"
#include "../models/utils.h"
#include "interface.h"

namespace Generators {

static Ort::Allocator* ort_allocator_{};
const char* label_cpu = "cpu";

struct CpuMemory final : DeviceBuffer {
// Allocates an owned CPU buffer of `size` bytes through the registered ORT
// allocator. On CPU the device pointer and the CPU pointer are the same.
// NOTE(review): assumes InitOrt has already registered ort_allocator_ — verify
// construction order against DeviceInterface setup.
CpuMemory(size_t size) : owned_{true} {
  size_in_bytes_ = size;
  // The diff residue showed both the old `new uint8_t[]` and the new
  // allocator call; only the allocator path is kept — allocating twice would
  // leak the first buffer.
  p_cpu_ = p_device_ = static_cast<uint8_t*>(ort_allocator_->Alloc(size_in_bytes_));
}

CpuMemory(void* p, size_t size) : owned_{false} {
Expand All @@ -22,33 +24,79 @@ struct CpuMemory final : DeviceBuffer {

// Releases the buffer when this object owns it. Memory allocated via
// ort_allocator_->Alloc must be released via ort_allocator_->Free — the span
// previously contained a stale `delete[]` line from the old implementation,
// which together with Free would have been a double release.
~CpuMemory() override {
  if (owned_)
    ort_allocator_->Free(p_device_);
}

// Identifies this buffer as CPU memory; the CPU<->device copy hooks are no-ops
// because on this device the "device" memory IS host memory.
const char* GetType() const override { return label_cpu; }
void AllocateCpu() override {} // Nothing to do, device is also CPU
void CopyDeviceToCpu() override {} // Nothing to do, device is also CPU
void CopyCpuToDevice() override {} // Nothing to do, device is also CPU
// Copies `size_in_bytes` bytes from `source` (starting at `begin_source`) into
// this buffer at `begin_dest`, routed through CPU memory so any source device
// type is handled uniformly. The old inline memcpy/throw implementation that
// the diff residue interleaved here is removed in favor of the single
// CopyThroughCpu call.
void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override {
  CopyThroughCpu(*this, begin_dest, source, begin_source, size_in_bytes);
}

// Fills the entire buffer with zero bytes.
void Zero() override {
memset(p_device_, 0, size_in_bytes_);
}

bool owned_;
};

struct CpuInterface : DeviceInterface {
std::shared_ptr<DeviceBuffer> AllocateBase(size_t size, bool cpu_accessible) override {
// cpu_accessible is ignored, as with the cpu, the device is also the cpu
// Registers the default ORT CPU allocator with this interface at construction
// time, so all subsequent CPU buffer allocations go through ORT.
CpuInterface() {
InitOrt(*Ort::api, Ort::Allocator::GetWithDefaultOptions());
}

// Stores the ORT allocator used for all CPU allocations. The api handle is
// unused here; the assert documents that this is expected to run exactly once.
void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override {
assert(!ort_allocator_);
ort_allocator_ = &allocator;
}

// Returns the allocator registered via InitOrt.
// NOTE(review): dereferences ort_allocator_ unchecked — assumes InitOrt has
// run (the constructor does this); confirm no caller reaches here earlier.
Ort::Allocator& GetAllocator() override {
return *ort_allocator_;
}

// Allocates an owned CPU buffer of `size` bytes.
std::shared_ptr<DeviceBuffer> AllocateBase(size_t size) override {
return std::make_shared<CpuMemory>(size);
}

// Wraps caller-owned memory in a DeviceBuffer without taking ownership
// (the wrapping CpuMemory constructor sets owned_ = false, so it is not freed).
std::shared_ptr<DeviceBuffer> WrapMemoryBase(void* p, size_t size) override {
return std::make_shared<CpuMemory>(p, size);
}

// Element-wise type conversion between two pre-allocated tensors.
// Supported conversions: float32 -> float16, float16 -> float32, and
// int32 -> int64. Throws if the element counts differ, if the types are
// identical, or if the conversion pair is not one of the above.
bool Cast(OrtValue& input, OrtValue& output) override {
  auto src_info = input.GetTensorTypeAndShapeInfo();
  auto dst_info = output.GetTensorTypeAndShapeInfo();

  const auto src_type = src_info->GetElementType();
  const auto dst_type = dst_info->GetElementType();

  const auto count = src_info->GetElementCount();
  if (count != dst_info->GetElementCount())
    throw std::runtime_error("Cast - input and output element counts do not match");
  if (src_type == dst_type)
    throw std::runtime_error("Cast - input and output types are the same");

  if (src_type == Ort::TypeToTensorType<float> && dst_type == Ort::TypeToTensorType<Ort::Float16_t>) {
    const float* src = input.GetTensorData<float>();
    uint16_t* dst = output.GetTensorMutableData<uint16_t>();
    for (size_t elem = 0; elem < count; elem++)
      dst[elem] = FastFloat32ToFloat16(src[elem]);
  } else if (src_type == Ort::TypeToTensorType<Ort::Float16_t> && dst_type == Ort::TypeToTensorType<float>) {
    const uint16_t* src = input.GetTensorData<uint16_t>();
    float* dst = output.GetTensorMutableData<float>();
    for (size_t elem = 0; elem < count; elem++)
      dst[elem] = FastFloat16ToFloat32(src[elem]);
  } else if (src_type == Ort::TypeToTensorType<int32_t> && dst_type == Ort::TypeToTensorType<int64_t>) {
    const int32_t* src = input.GetTensorData<int32_t>();
    int64_t* dst = output.GetTensorMutableData<int64_t>();
    for (size_t elem = 0; elem < count; elem++)
      dst[elem] = src[elem];
  } else
    throw std::runtime_error("Cast - Unimplemented cast");
  return true;
}

// Factory methods for the CPU-backed search strategies.
std::unique_ptr<Search> CreateGreedy(const GeneratorParams& params) override { return std::make_unique<GreedySearch_Cpu>(params); }
std::unique_ptr<Search> CreateBeam(const GeneratorParams& params) override { return std::make_unique<BeamSearch_Cpu>(params); }

Expand Down
2 changes: 1 addition & 1 deletion src/cuda/beam_search_scorer_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
namespace Generators {

BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const GeneratorParams& parameters)
: stream_{parameters.cuda_stream} {
: stream_{GetStream()} {
state_cpu_ = CudaMallocHostArray<cuda::BeamScorerState>(1);
state_cpu_->batch_size_ = static_cast<size_t>(parameters.search.batch_size);
state_cpu_->num_beams_ = static_cast<size_t>(parameters.search.num_beams);
Expand Down
1 change: 1 addition & 0 deletions src/cuda/beam_search_scorer_cuda.cuh
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "models/onnxruntime_api.h"
#include "smartptrs.h"

namespace Generators {
Expand Down
13 changes: 7 additions & 6 deletions src/cuda/cuda_sampling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "span.h"
#include "beam_search_topk.h"
#include "cuda_sampling.cuh"
#include "models/onnxruntime_api.h"
#include "smartptrs.h"
#include <cuda_runtime.h>
#include <cub/cub.cuh>
Expand Down Expand Up @@ -297,22 +298,22 @@ __global__ void SoftmaxBlockForward(outscalar_t* output, scalar_t* input, int cl
}

// Launches a block-per-batch softmax (or log-softmax) kernel over `input`,
// writing to `output`. The stream is now passed by value (cudaStream_t is
// already a handle); the stale by-pointer signature and launch lines the diff
// residue interleaved here are removed.
template <bool is_log_softmax>
void DispatchBlockwiseSoftmaxForward(cudaStream_t stream, float* output, const float* input, int softmax_elements,
                                     int input_stride, int output_stride, int batch_count, float temperature) {
  // One block per batch row; ILP lanes of float4-width vectorized loads.
  dim3 grid(batch_count);
  constexpr int ILP = sizeof(float4) / sizeof(float);
  dim3 block = SoftmaxGetBlockSize(ILP, softmax_elements);
  if (is_log_softmax) {
    SoftmaxBlockForward<ILP, float, float, float, LogSoftmaxForwardEpilogue>
        <<<grid, block, block.x * sizeof(float), stream>>>(output, const_cast<float*>(input),
                                                           softmax_elements, input_stride, output_stride, temperature);
  } else {
    SoftmaxBlockForward<ILP, float, float, float, SoftmaxForwardEpilogue>
        <<<grid, block, block.x * sizeof(float), stream>>>(output, const_cast<float*>(input),
                                                           softmax_elements, input_stride, output_stride, temperature);
  }
}
template void DispatchBlockwiseSoftmaxForward<true>(cudaStream_t, float*, const float*, int, int, int, int, float);

// Populate Kernels and Launchers

Expand Down Expand Up @@ -521,7 +522,7 @@ void LaunchSampleKernel(SamplingData* data, cudaStream_t stream, float* scores,
void SoftmaxAndSort(SamplingData* data, cudaStream_t stream, float* scores_in, float* scores_out, int* indices_out, int vocab_size, int batch_size, float temperature) {
// Softmax scores
std::span<float> scores{data->scores_softmaxed.get(), static_cast<size_t>(vocab_size * batch_size)};
DispatchBlockwiseSoftmaxForward<false>(&stream, scores.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
DispatchBlockwiseSoftmaxForward<false>(stream, scores.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
// Sort indices by scores
std::span<int> offsets_gpu{data->offsets.get(), static_cast<size_t>(batch_size + 1)};
LaunchPopulateOffsets(offsets_gpu.data(), vocab_size, batch_size, stream);
Expand Down Expand Up @@ -550,7 +551,7 @@ void LaunchGetTopKSubsetFullSort(SamplingData* data, cudaStream_t stream, float*
void GetTopKSubset(SamplingData* data, cudaStream_t stream, float* scores_in, float* scores_out, int* indices_out, int vocab_size, int batch_size, int k, float temperature) {
// Softmax scores
std::span<float> scores_softmaxed{data->scores_softmaxed.get(), static_cast<size_t>(vocab_size * batch_size)};
DispatchBlockwiseSoftmaxForward<false>(&stream, scores_softmaxed.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
DispatchBlockwiseSoftmaxForward<false>(stream, scores_softmaxed.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
// Get top k subset
#define GetTopK(max_k) \
LaunchGetTopKSubset<max_k>(stream, \
Expand Down
2 changes: 1 addition & 1 deletion src/cuda/cuda_sampling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ void LaunchPopulateIndices(int* indices, int size, int batch_size, cudaStream_t
void GetSample(SamplingData* data, cudaStream_t stream, int32_t* d_next_token, float* d_scores, int vocab_size, int batch_size, int k, float p, float temperature);

// Block-per-batch softmax launcher; stream is passed by value. The diff
// residue showed both the old cudaStream_t* declaration and this one —
// only the new by-value form is kept to match the definition.
template <bool is_log_softmax>
void DispatchBlockwiseSoftmaxForward(cudaStream_t stream, float* output, const float* input, int softmax_elements, int input_stride, int output_stride, int batch_count, float temperature = 1.0);

} // namespace cuda
} // namespace Generators
Loading
Loading