remove task inf_type
ngxson committed Dec 7, 2024
commit 090a113417117b90c7c8a946316caa6d9cba7544
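This commit removes the separate `server_task_inf_type` enum and folds its values into `server_task_type` itself, so every task carries a single type (completion, embedding, rerank, infill, cancel, ...) and slots decide non-causal batching through a new `is_non_causal()` helper instead of comparing `inf_type` in several places. A minimal before/after sketch of what the change means for code that builds tasks, abridged from the handlers touched below (queue submission and error handling omitted; the `non_causal` local is only illustrative):

    // before: a generic inference type plus a separate sub-type
    // server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_EMBEDDING);

    // after: the task type alone identifies the kind of work
    server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
    task.id            = ctx_server.queue_tasks.get_new_id();
    task.prompt_tokens = std::move(tokenized_prompts[i]);

    // slots derive batching behaviour from the task type:
    // embedding and rerank prompts are processed non-causally
    bool non_causal = slot.is_non_causal();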
68 changes: 35 additions & 33 deletions examples/server/server.cpp
@@ -54,7 +54,10 @@ enum server_state {
  };

  enum server_task_type {
- SERVER_TASK_TYPE_INFERENCE,
+ SERVER_TASK_TYPE_COMPLETION,
+ SERVER_TASK_TYPE_EMBEDDING,
+ SERVER_TASK_TYPE_RERANK,
+ SERVER_TASK_TYPE_INFILL,
  SERVER_TASK_TYPE_CANCEL,
  SERVER_TASK_TYPE_NEXT_RESPONSE,
  SERVER_TASK_TYPE_METRICS,
@@ -64,13 +67,6 @@ enum server_task_type {
  SERVER_TASK_TYPE_SET_LORA,
  };

- enum server_task_inf_type {
- SERVER_TASK_INF_TYPE_COMPLETION,
- SERVER_TASK_INF_TYPE_EMBEDDING,
- SERVER_TASK_INF_TYPE_RERANK,
- SERVER_TASK_INF_TYPE_INFILL,
- };
-
  // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
  enum error_type {
  ERROR_TYPE_INVALID_REQUEST,
@@ -163,8 +159,7 @@ struct server_task {
  int id = -1; // to be filled by server_queue
  int index = -1; // used when there are multiple prompts (batch request)

- server_task_type     type;
- server_task_inf_type inf_type;
+ server_task_type type;

  // used by SERVER_TASK_TYPE_CANCEL
  int id_target = -1;
@@ -185,9 +180,7 @@ struct server_task {
  // used by SERVER_TASK_TYPE_METRICS
  bool metrics_reset_bucket = false;

- server_task(
- server_task_type type,
- server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION) : type(type), inf_type(inf_type) {}
+ server_task(server_task_type type) : type(type) {}

  static slot_params params_from_json_cmpl(
  const llama_model * model,
@@ -893,6 +886,9 @@ struct server_slot {
  int id;
  int id_task = -1;

+ // only used for completion/embedding/infill/rerank
+ server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
+
  llama_batch batch_spec = {};

  llama_context * ctx = nullptr;
@@ -931,8 +927,6 @@ struct server_slot {
  llama_tokens cache_tokens;
  std::vector<completion_token_output> generated_token_probs;

- server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
-
  bool has_next_token = true;
  bool has_new_line = false;
  bool truncated = false;
@@ -972,11 +966,15 @@ struct server_slot {
  n_past = 0;
  n_sent_text = 0;
  n_sent_token_probs = 0;
- inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
+ task_type = SERVER_TASK_TYPE_COMPLETION;

  generated_token_probs.clear();
  }

+ bool is_non_causal() const {
+ return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
+ }
+
  bool has_budget(const common_params & global_params) {
  if (params.n_predict == -1 && global_params.n_predict == -1) {
  return true; // limitless
@@ -1088,6 +1086,7 @@ struct server_slot {
  {"n_ctx", n_ctx},
  {"speculative", can_speculate()},
  {"is_processing", is_processing()},
+ {"non_causal", is_non_causal()},
  {"params", params.to_json()},
  {"prompt", common_detokenize(ctx, prompt_tokens)},
  {"next_token",
@@ -1653,8 +1652,8 @@ struct server_context {
  bool launch_slot_with_task(server_slot & slot, const server_task & task) {
  slot.reset();
  slot.id_task = task.id;
- slot.inf_type = task.inf_type;
  slot.index = task.index;
+ slot.task_type = task.type;
  slot.params = std::move(task.params);
  slot.prompt_tokens = std::move(task.prompt_tokens);

@@ -2120,7 +2119,10 @@ struct server_context {

  void process_single_task(server_task task) {
  switch (task.type) {
- case SERVER_TASK_TYPE_INFERENCE:
+ case SERVER_TASK_TYPE_COMPLETION:
+ case SERVER_TASK_TYPE_INFILL:
+ case SERVER_TASK_TYPE_EMBEDDING:
+ case SERVER_TASK_TYPE_RERANK:
  {
  const int id_slot = task.id_selected_slot;

@@ -2462,7 +2464,7 @@ struct server_context {
  continue;
  }

- if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+ if (slot.is_non_causal()) {
  if (slot.n_prompt_tokens > n_ubatch) {
  slot.release();
  send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -2577,18 +2579,15 @@ struct server_context {
  }

  // non-causal tasks require to fit the entire prompt in the physical batch
- if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+ if (slot.is_non_causal()) {
  // cannot fit the prompt in the current batch - will try next iter
  if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
  continue;
  }
  }

  // check that we are in the right batch_type, if not defer the slot
- const bool slot_type =
- slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
- slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
-
+ int slot_type = slot.is_non_causal();
  if (batch_type == -1) {
  batch_type = slot_type;
  } else if (batch_type != slot_type) {
@@ -2705,15 +2704,15 @@ struct server_context {
  }

  if (slot.state == SLOT_STATE_DONE_PROMPT) {
- if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
+ if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
  // prompt evaluated for embedding
  send_embedding(slot, batch_view);
  slot.release();
  slot.i_batch = -1;
  continue; // continue loop of slots
  }

- if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+ if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
  send_rerank(slot, batch_view);
  slot.release();
  slot.i_batch = -1;
@@ -3352,11 +3351,13 @@ int main(int argc, char ** argv) {
  // handle completion-like requests (completion, chat, infill)
  // we can optionally provide a custom format for partial results and final results
  const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
- server_task_inf_type inf_type,
+ server_task_type type,
  json & data,
  httplib::Response & res,
  bool oaicompat = false,
  bool oaicompat_chat = false) {
+ GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
+
  if (ctx_server.params_base.embedding) {
  res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
  return;
@@ -3369,7 +3370,8 @@ int main(int argc, char ** argv) {
  std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true);
  tasks.reserve(tokenized_prompts.size());
  for (size_t i = 0; i < tokenized_prompts.size(); i++) {
- server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, inf_type);
+ server_task task = server_task(type);
+
  task.id = ctx_server.queue_tasks.get_new_id();
  task.index = i;

@@ -3450,7 +3452,7 @@ int main(int argc, char ** argv) {
  const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
  json data = json::parse(req.body);
  return handle_completions_generic(
- SERVER_TASK_INF_TYPE_COMPLETION,
+ SERVER_TASK_TYPE_COMPLETION,
  data,
  res,
  /* oaicompat */ false,
@@ -3504,7 +3506,7 @@ int main(int argc, char ** argv) {
  }
  data["input_extra"] = input_extra; // default to empty array if it's not exist

- return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
+ return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
  };

  const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -3515,7 +3517,7 @@ int main(int argc, char ** argv) {

  json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
  return handle_completions_generic(
- SERVER_TASK_INF_TYPE_COMPLETION,
+ SERVER_TASK_TYPE_COMPLETION,
  data,
  res,
  /* oaicompat */ true,
@@ -3616,7 +3618,7 @@ int main(int argc, char ** argv) {
  std::vector<server_task> tasks;
  std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
  for (size_t i = 0; i < tokenized_prompts.size(); i++) {
- server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_EMBEDDING);
+ server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
  task.id = ctx_server.queue_tasks.get_new_id();
  task.index = i;
  task.prompt_tokens = std::move(tokenized_prompts[i]);
@@ -3698,7 +3700,7 @@ int main(int argc, char ** argv) {
  std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true);
  tasks.reserve(tokenized_docs.size());
  for (size_t i = 0; i < tokenized_docs.size(); i++) {
- server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_RERANK);
+ server_task task = server_task(SERVER_TASK_TYPE_RERANK);
  task.id = ctx_server.queue_tasks.get_new_id();
  task.index = i;
  task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);