diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 631fc86d..a15f809d 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -6,6 +6,8 @@ on:
 env:
   MODEL_URL: https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf
   MODEL_NAME: codellama-7b.Q2_K.gguf
+  RERANKING_MODEL_URL: https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf
+  RERANKING_MODEL_NAME: jina-reranker-v1-tiny-en-Q4_0.gguf

 jobs:
   build-and-test-linux:
@@ -21,8 +23,12 @@ jobs:
         run: |
           mvn compile
           .github/build.sh -DLLAMA_VERBOSE=ON
-      - name: Download model
+      - name: Download text generation model
        run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME}
+      - name: Download reranking model
+        run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME}
+      - name: List files in models directory
+        run: ls -l models/
       - name: Run tests
         run: mvn test
       - if: failure()
@@ -53,8 +59,12 @@ jobs:
         run: |
           mvn compile
           .github/build.sh ${{ matrix.target.cmake }}
-      - name: Download model
+      - name: Download text generation model
         run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME}
+      - name: Download reranking model
+        run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME}
+      - name: List files in models directory
+        run: ls -l models/
       - name: Run tests
         run: mvn test
       - if: failure()
@@ -79,6 +89,10 @@ jobs:
           .github\build.bat -DLLAMA_VERBOSE=ON
       - name: Download model
         run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME
+      - name: Download reranking model
+        run: curl -L $env:RERANKING_MODEL_URL --create-dirs -o models/$env:RERANKING_MODEL_NAME
+      - name: List files in models directory
+        run: ls -l models/
       - name: Run tests
         run: mvn test
       - if: failure()
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index 04b4a147..64032028 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -11,22 +11,25 @@ on:
 env:
   MODEL_URL: "https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf"
   MODEL_NAME: "codellama-7b.Q2_K.gguf"
+  RERANKING_MODEL_URL: "https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf"
+  RERANKING_MODEL_NAME: "jina-reranker-v1-tiny-en-Q4_0.gguf"

 jobs:
-  build-linux-cuda:
-    name: Build Linux x86-64 CUDA12
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v4
-      - name: Build libraries
-        shell: bash
-        run: |
-          .github/dockcross/dockcross-manylinux_2_28-x64 .github/build_cuda_linux.sh "-DOS_NAME=Linux -DOS_ARCH=x86_64"
-      - name: Upload artifacts
-        uses: actions/upload-artifact@v4
-        with:
-          name: linux-libraries-cuda
-          path: ${{ github.workspace }}/src/main/resources_linux_cuda/de/kherud/llama/
+# todo: doesn't work with the newest llama.cpp version
+#  build-linux-cuda:
+#    name: Build Linux x86-64 CUDA12
+#    runs-on: ubuntu-latest
+#    steps:
+#      - uses: actions/checkout@v4
+#      - name: Build libraries
+#        shell: bash
+#        run: |
+#          .github/dockcross/dockcross-manylinux_2_28-x64 .github/build_cuda_linux.sh "-DOS_NAME=Linux -DOS_ARCH=x86_64"
+#      - name: Upload artifacts
+#        uses: actions/upload-artifact@v4
+#        with:
+#          name: linux-libraries-cuda
+#          path: ${{ github.workspace }}/src/main/resources_linux_cuda/de/kherud/llama/

   build-linux-docker:
     name: Build ${{ matrix.target.os }}-${{ matrix.target.arch }}
@@ -102,23 +105,24 @@ jobs:
       - {
          os: Windows,
          arch: x86_64,
-          cmake: '-G "Visual Studio 17 2022" -A "x64"'
-        }
-      - {
-          os: Windows,
-          arch: aarch64,
-          cmake: '-G "Visual Studio 17 2022" -A "ARM64"'
+          cmake: '-G "Visual Studio 16 2019" -A "x64"'
         }
       - {
          os: Windows,
          arch: x86,
-          cmake: '-G "Visual Studio 17 2022" -A "Win32"'
-        }
-      - {
-          os: Windows,
-          arch: arm,
-          cmake: '-G "Visual Studio 17 2022" -A "ARM"'
+          cmake: '-G "Visual Studio 16 2019" -A "Win32"'
         }
+# MSVC aarch64 builds no longer work with llama.cpp (requires clang instead)
+#      - {
+#          os: Windows,
+#          arch: aarch64,
+#          cmake: '-G "Visual Studio 16 2019" -A "ARM64"'
+#        }
+#      - {
+#          os: Windows,
+#          arch: arm,
+#          cmake: '-G "Visual Studio 16 2019" -A "ARM"'
+#        }
     steps:
       - uses: actions/checkout@v4
       - name: Build libraries
@@ -142,8 +146,10 @@ jobs:
         with:
           name: Linux-x86_64-libraries
           path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/
-      - name: Download model
+      - name: Download text generation model
         run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME}
+      - name: Download reranking model
+        run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME}
       - uses: actions/setup-java@v4
         with:
           distribution: 'zulu'
@@ -193,7 +199,7 @@ jobs:
   publish:
     if: ${{ github.event_name != 'workflow_dispatch' || github.event.inputs.build_only == 'no' }}
-    needs: [ test-linux,build-macos-native,build-win-native,build-linux-cuda ]
+    needs: [ test-linux,build-macos-native,build-win-native ] #,build-linux-cuda
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
@@ -202,10 +208,10 @@ jobs:
           pattern: "*-libraries"
           merge-multiple: true
           path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/
-      - uses: actions/download-artifact@v4
-        with:
-          name: linux-libraries-cuda
-          path: ${{ github.workspace }}/src/main/resources_linux_cuda/de/kherud/llama/
+#      - uses: actions/download-artifact@v4
+#        with:
+#          name: linux-libraries-cuda
+#          path: ${{ github.workspace }}/src/main/resources_linux_cuda/de/kherud/llama/
       - name: Set up Maven Central Repository
         uses: actions/setup-java@v3
         with:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2278d454..96c62950 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -25,7 +25,7 @@ set(LLAMA_BUILD_COMMON ON)
 FetchContent_Declare(
     llama.cpp
     GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
-    GIT_TAG b4831
+    GIT_TAG b4916
 )
 FetchContent_MakeAvailable(llama.cpp)
@@ -69,7 +69,7 @@ endif()

 # include jni.h and jni_md.h
 if(NOT DEFINED JNI_INCLUDE_DIRS)
-    if(OS_NAME MATCHES "^Linux" OR OS_NAME STREQUAL "Mac")
+    if(OS_NAME MATCHES "^Linux" OR OS_NAME STREQUAL "Mac" OR OS_NAME STREQUAL "Darwin")
         set(JNI_INCLUDE_DIRS .github/include/unix)
     elseif(OS_NAME STREQUAL "Windows")
         set(JNI_INCLUDE_DIRS .github/include/windows)
diff --git a/README.md b/README.md
index 971c06af..1bc278b1 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
 ![Java 11+](https://img.shields.io/badge/Java-11%2B-informational)
-![llama.cpp b3534](https://img.shields.io/badge/llama.cpp-%23b3534-informational)
+![llama.cpp b4916](https://img.shields.io/badge/llama.cpp-%23b4916-informational)

 # Java Bindings for [llama.cpp](https://github.com/ggerganov/llama.cpp)
@@ -17,7 +17,7 @@ Inference of Meta's LLaMA model (and others) in pure C/C++.
 3. [Android](#importing-in-android)

 > [!NOTE]
-> Now with support for Llama 3, Phi-3, and flash attention
+> Now with support for Gemma 3

 ## Quick Start
@@ -27,18 +27,7 @@ Access this library via Maven:

 ```xml
 <dependency>
     <groupId>de.kherud</groupId>
     <artifactId>llama</artifactId>
-    <version>3.4.1</version>
-</dependency>
-```
-
-By default the library artifact is built only with CPU inference support.
-To enable CUDA, use a `cuda12-linux-x86-64` maven classifier:
-
-```xml
-<dependency>
-    <groupId>de.kherud</groupId>
-    <artifactId>llama</artifactId>
-    <version>3.4.1</version>
-    <classifier>cuda12-linux-x86-64</classifier>
+    <version>4.1.0</version>
 </dependency>
 ```
@@ -50,11 +39,7 @@ We support CPU inference for the following platforms out of the box:

 - Linux x86-64, aarch64
 - MacOS x86-64, aarch64 (M-series)
-- Windows x86-64, x64, arm (32 bit)
-
-For GPU inference, we support:
-
-- Linux x86-64 with CUDA 12.1+
+- Windows x86-64, x64

 If any of these match your platform, you can include the Maven dependency and get started.
@@ -88,13 +73,9 @@ All compiled libraries will be put in a resources directory matching your platfo

 #### Library Location

-This project has to load three shared libraries:
-
-- ggml
-- llama
-- jllama
+This project has to load a single shared library `jllama`.

-Note, that the file names vary between operating systems, e.g., `ggml.dll` on Windows, `libggml.so` on Linux, and `libggml.dylib` on macOS.
+Note that the file name varies between operating systems, e.g., `jllama.dll` on Windows, `libjllama.so` on Linux, and `libjllama.dylib` on macOS.

 The application will search in the following order in the following locations:
@@ -105,14 +86,6 @@
 - From the **JAR**: If any of the libraries weren't found yet, the application will try to use a prebuilt shared library. This of course only works for the [supported platforms](#no-setup-required).

-Not all libraries have to be in the same location.
-For example, if you already have a llama.cpp and ggml version you can install them as a system library and rely on the jllama library from the JAR.
-This way, you don't have to compile anything.
-
-#### CUDA
-
-On Linux x86-64 with CUDA 12.1+, the library assumes that your CUDA libraries are findable in `java.library.path`. If you have CUDA installed in a non-standard location, then point the `java.library.path` to the directory containing the `libcudart.so.12` library.

 ## Documentation

 ### Example
@@ -124,8 +97,8 @@ public class Example {

     public static void main(String... args) throws IOException {
         ModelParameters modelParams = new ModelParameters()
-                .setModelFilePath("/path/to/model.gguf")
-                .setNGpuLayers(43);
+                .setModel("models/mistral-7b-instruct-v0.2.Q2_K.gguf")
+                .setGpuLayers(43);

         String system = "This is a conversation between User and Llama, a friendly chatbot.\n" +
                 "Llama is helpful, kind, honest, good at writing, and never fails to answer any " +
@@ -144,8 +117,8 @@ public class Example {
             InferenceParameters inferParams = new InferenceParameters(prompt)
                     .setTemperature(0.7f)
                     .setPenalizeNl(true)
-                    .setMirostat(InferenceParameters.MiroStat.V2)
-                    .setAntiPrompt("\n");
+                    .setMiroStat(MiroStat.V2)
+                    .setStopStrings("User:");
             for (LlamaOutput output : model.generate(inferParams)) {
                 System.out.print(output);
                 prompt += output;
@@ -165,7 +138,7 @@ model to your prompt in order to extend the context. If there is repeated conten
 cache this, to improve performance.

 ```java
-ModelParameters modelParams = new ModelParameters().setModelFilePath("/path/to/model.gguf");
+ModelParameters modelParams = new ModelParameters().setModel("/path/to/model.gguf");
 InferenceParameters inferParams = new InferenceParameters("Tell me a joke.");
 try (LlamaModel model = new LlamaModel(modelParams)) {
     // Stream a response and access more information about each output.
@@ -197,9 +170,8 @@ for every inference task. All non-specified options have sensible defaults.
 ```java
 ModelParameters modelParams = new ModelParameters()
-        .setModelFilePath("/path/to/model.gguf")
-        .setLoraAdapter("/path/to/lora/adapter")
-        .setLoraBase("/path/to/lora/base");
+        .setModel("/path/to/model.gguf")
+        .addLoraAdapter("/path/to/lora/adapter");
 String grammar = """
         root  ::= (expr "=" term "\\n")+
         expr  ::= term ([-+*/] term)*
@@ -234,7 +206,7 @@ LlamaModel.setLogger(null, (level, message) -> {});

 ## Importing in Android

 You can use this library in Android project.
-1. Add java-llama.cpp as a submodule in your android `app` project directory
+1. Add java-llama.cpp as a submodule in your Android `app` project directory
 ```shell
 git submodule add https://github.com/kherud/java-llama.cpp
 ```
diff --git a/pom.xml b/pom.xml
index c081e192..67b366ee 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1,14 +1,16 @@
-<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
     <modelVersion>4.0.0</modelVersion>

     <groupId>de.kherud</groupId>
     <artifactId>llama</artifactId>
-    <version>4.0.0</version>
+    <version>4.2.0</version>
     <packaging>jar</packaging>

     <name>${project.groupId}:${project.artifactId}</name>
-    <description>Java Bindings for llama.cpp - A Port of Facebook's LLaMA model in C/C++.</description>
+    <description>Java Bindings for llama.cpp - A Port of Facebook's LLaMA model
+        in C/C++.</description>
     <url>https://github.com/kherud/java-llama.cpp</url>
@@ -39,7 +41,8 @@
             <id>ossrh</id>
-            <url>https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/</url>
+            <url>
+                https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/</url>
@@ -71,17 +74,21 @@
                 <artifactId>maven-compiler-plugin</artifactId>
                 <version>3.13.0</version>
-				<executions>
+                <executions>
                     <execution>
                         <id>gpu</id>
                         <phase>compile</phase>
-                        <goals><goal>compile</goal></goals>
+                        <goals>
+                            <goal>compile</goal>
+                        </goals>
                         <compilerArgs>
                             <arg>-h</arg>
                             <arg>src/main/cpp</arg>
                         </compilerArgs>
-                        <outputDirectory>${project.build.outputDirectory}_cuda</outputDirectory>
+                        <outputDirectory>
+                            ${project.build.outputDirectory}_cuda</outputDirectory>
@@ -98,10 +105,12 @@
                             <goal>copy-resources</goal>
                         </goals>
                         <configuration>
-                            <outputDirectory>${project.build.outputDirectory}_cuda</outputDirectory>
+                            <outputDirectory>
+                                ${project.build.outputDirectory}_cuda</outputDirectory>
                             <resources>
                                 <resource>
-                                    <directory>${basedir}/src/main/resources_linux_cuda/</directory>
+                                    <directory>
+                                        ${basedir}/src/main/resources_linux_cuda/</directory>
                                     <includes>
                                         <include>**/*.*</include>
                                     </includes>
@@ -176,7 +185,8 @@
                 <artifactId>maven-jar-plugin</artifactId>
                 <version>3.4.2</version>
-				<executions>
+                <executions>
                     <execution>
                         <id>cuda</id>
                         <phase>package</phase>
                         <goals>
                             <goal>jar</goal>
                         </goals>
                         <configuration>
                             <classifier>cuda12-linux-x86-64</classifier>
-                            <classesDirectory>${project.build.outputDirectory}_cuda</classesDirectory>
+                            <classesDirectory>
+                                ${project.build.outputDirectory}_cuda</classesDirectory>
diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp
index 0db026ea..11c80ae0 100644
--- a/src/main/cpp/jllama.cpp
+++ b/src/main/cpp/jllama.cpp
@@ -452,16 +452,6 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
         llama_init_dft.context.reset();
     }

-    ctx_server->chat_templates = common_chat_templates_init(ctx_server->model, params.chat_template);
-    try {
-        common_chat_format_example(ctx_server->chat_templates.get(), params.use_jinja);
-    } catch (const std::exception &e) {
-        SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This "
This " - "may cause the model to output suboptimal responses\n", - __func__); - ctx_server->chat_templates = common_chat_templates_init(ctx_server->model, "chatml"); - } - // print sample chat example to make it clear which template is used LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, common_chat_templates_source(ctx_server->chat_templates.get()), @@ -634,7 +624,6 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, json error = nullptr; server_task_result_ptr result = ctx_server->queue_results.recv(id_task); - ctx_server->queue_results.remove_waiting_task_id(id_task); json response_str = result->to_json(); if (result->is_error()) { @@ -644,6 +633,10 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, return nullptr; } + if (result->is_stop()) { + ctx_server->queue_results.remove_waiting_task_id(id_task); + } + const auto out_res = result->to_json(); // Extract "embedding" as a vector of vectors (2D array) @@ -679,6 +672,102 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, return j_embedding; } +JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jobject obj, jstring jprompt, + jobjectArray documents) { + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) + + if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) { + env->ThrowNew(c_llama_error, + "This server does not support reranking. Start it with `--reranking` and without `--embedding`"); + return nullptr; + } + + const std::string prompt = parse_jstring(env, jprompt); + + const auto tokenized_query = tokenize_mixed(ctx_server->vocab, prompt, true, true); + + json responses = json::array(); + + std::vector tasks; + const jsize amount_documents = env->GetArrayLength(documents); + auto *document_array = parse_string_array(env, documents, amount_documents); + auto document_vector = std::vector(document_array, document_array + amount_documents); + free_string_array(document_array, amount_documents); + + std::vector tokenized_docs = tokenize_input_prompts(ctx_server->vocab, document_vector, true, true); + + tasks.reserve(tokenized_docs.size()); + for (int i = 0; i < tokenized_docs.size(); i++) { + auto task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server->queue_tasks.get_new_id(); + task.index = i; + task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]); + tasks.push_back(task); + } + ctx_server->queue_results.add_waiting_tasks(tasks); + ctx_server->queue_tasks.post(tasks); + + // get the result + std::unordered_set task_ids = server_task::get_list_id(tasks); + std::vector results(task_ids.size()); + + // Create a new HashMap instance + jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map); + if (o_probabilities == nullptr) { + env->ThrowNew(c_llama_error, "Failed to create HashMap object."); + return nullptr; + } + + for (int i = 0; i < (int)task_ids.size(); i++) { + server_task_result_ptr result = ctx_server->queue_results.recv(task_ids); + if (result->is_error()) { + auto response = result->to_json()["message"].get(); + for (const int id_task : task_ids) { + ctx_server->queue_results.remove_waiting_task_id(id_task); + } + env->ThrowNew(c_llama_error, response.c_str()); + return nullptr; + } + + const auto out_res = result->to_json(); + + if (result->is_stop()) { + for (const int id_task : task_ids) { + 
+                ctx_server->queue_results.remove_waiting_task_id(id_task);
+            }
+        }
+
+        int index = out_res["index"].get<int>();
+        float score = out_res["score"].get<float>();
+        std::string tok_str = document_vector[index];
+        jstring jtok_str = env->NewStringUTF(tok_str.c_str());
+
+        jobject jprob = env->NewObject(c_float, cc_float, score);
+        env->CallObjectMethod(o_probabilities, m_map_put, jtok_str, jprob);
+        env->DeleteLocalRef(jtok_str);
+        env->DeleteLocalRef(jprob);
+    }
+    jbyteArray jbytes = parse_jbytes(env, prompt);
+    return env->NewObject(c_output, cc_output, jbytes, o_probabilities, true);
+}
+
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams) {
+    jlong server_handle = env->GetLongField(obj, f_model_pointer);
+    auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
+
+    std::string c_params = parse_jstring(env, jparams);
+    json data = json::parse(c_params);
+
+    json templateData =
+        oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja,
+                                          ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get());
+    std::string tok_str = templateData.at("prompt");
+    jstring jtok_str = env->NewStringUTF(tok_str.c_str());
+
+    return jtok_str;
+}
+
 JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) {
     jlong server_handle = env->GetLongField(obj, f_model_pointer);
     auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
@@ -761,4 +850,4 @@ JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammar
     nlohmann::ordered_json c_schema_json = nlohmann::ordered_json::parse(c_schema);
     const std::string c_grammar = json_schema_to_grammar(c_schema_json);
     return parse_jbytes(env, c_grammar);
-}
\ No newline at end of file
+}
diff --git a/src/main/cpp/jllama.h b/src/main/cpp/jllama.h
index 63d95b71..dc17fa83 100644
--- a/src/main/cpp/jllama.h
+++ b/src/main/cpp/jllama.h
@@ -84,6 +84,20 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv *, job
  */
 JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammarBytes(JNIEnv *, jclass, jstring);

+/*
+ * Class:     de_kherud_llama_LlamaModel
+ * Method:    rerank
+ * Signature: (Ljava/lang/String;[Ljava/lang/String;)Lde/kherud/llama/LlamaOutput;
+ */
+JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *, jobject, jstring, jobjectArray);
+
+/*
+ * Class:     de_kherud_llama_LlamaModel
+ * Method:    applyTemplate
+ * Signature: (Ljava/lang/String;)Ljava/lang/String;
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *, jobject, jstring);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp
index 66169a83..9686f2af 100644
--- a/src/main/cpp/server.hpp
+++ b/src/main/cpp/server.hpp
@@ -3269,151 +3269,3 @@ struct server_context {
         };
     }
 };
-
-static void common_params_handle_model_default(std::string &model, const std::string &model_url, std::string &hf_repo,
-                                               std::string &hf_file, const std::string &hf_token) {
-    if (!hf_repo.empty()) {
-        // short-hand to avoid specifying --hf-file -> default it to --model
-        if (hf_file.empty()) {
-            if (model.empty()) {
-                auto auto_detected = common_get_hf_file(hf_repo, hf_token);
-                if (auto_detected.first.empty() || auto_detected.second.empty()) {
-                    exit(1); // built without CURL, error message already printed
-                }
-                hf_repo = auto_detected.first;
-                hf_file = auto_detected.second;
-            } else {
-                hf_file = model;
-            }
-        }
-        // make sure model path is present (for caching purposes)
-        if (model.empty()) {
-            // this is to avoid different repo having same file name, or same file name in different subdirs
-            std::string filename = hf_repo + "_" + hf_file;
-            // to make sure we don't have any slashes in the filename
-            string_replace_all(filename, "/", "_");
-            model = fs_get_cache_file(filename);
-        }
-    } else if (!model_url.empty()) {
-        if (model.empty()) {
-            auto f = string_split<std::string>(model_url, '#').front();
-            f = string_split<std::string>(f, '?').front();
-            model = fs_get_cache_file(string_split<std::string>(f, '/').back());
-        }
-    } else if (model.empty()) {
-        model = DEFAULT_MODEL_PATH;
-    }
-}
-
-// parse the given jparams (see de.kherud.llama.args.ModelParameters#toString()) from JSON to the required C++ struct.
-static void server_params_parse(json jparams, common_params &params) {
-    common_params default_params;
-
-    params.sampling.seed = json_value(jparams, "seed", default_params.sampling.seed);
-    params.cpuparams.n_threads = json_value(jparams, "n_threads", default_params.cpuparams.n_threads);
-    params.speculative.cpuparams.n_threads =
-        json_value(jparams, "n_threads_draft", default_params.speculative.cpuparams.n_threads);
-    params.cpuparams_batch.n_threads = json_value(jparams, "n_threads_batch", default_params.cpuparams_batch.n_threads);
-    params.speculative.cpuparams_batch.n_threads =
-        json_value(jparams, "n_threads_batch_draft", default_params.speculative.cpuparams_batch.n_threads);
-    params.n_predict = json_value(jparams, "n_predict", default_params.n_predict);
-    params.n_ctx = json_value(jparams, "n_ctx", default_params.n_ctx);
-    params.n_batch = json_value(jparams, "n_batch", default_params.n_batch);
-    params.n_ubatch = json_value(jparams, "n_ubatch", default_params.n_ubatch);
-    params.n_keep = json_value(jparams, "n_keep", default_params.n_keep);
-
-    params.speculative.n_max = json_value(jparams, "n_draft", default_params.speculative.n_max);
-    params.speculative.n_min = json_value(jparams, "n_draft_min", default_params.speculative.n_min);
-
-    params.n_chunks = json_value(jparams, "n_chunks", default_params.n_chunks);
-    params.n_parallel = json_value(jparams, "n_parallel", default_params.n_parallel);
-    params.n_sequences = json_value(jparams, "n_sequences", default_params.n_sequences);
-    params.speculative.p_split = json_value(jparams, "p_split", default_params.speculative.p_split);
-    params.grp_attn_n = json_value(jparams, "grp_attn_n", default_params.grp_attn_n);
-    params.grp_attn_w = json_value(jparams, "grp_attn_w", default_params.grp_attn_w);
-    params.n_print = json_value(jparams, "n_print", default_params.n_print);
-    params.rope_freq_base = json_value(jparams, "rope_freq_base", default_params.rope_freq_base);
-    params.rope_freq_scale = json_value(jparams, "rope_freq_scale", default_params.rope_freq_scale);
-    params.yarn_ext_factor = json_value(jparams, "yarn_ext_factor", default_params.yarn_ext_factor);
-    params.yarn_attn_factor = json_value(jparams, "yarn_attn_factor", default_params.yarn_attn_factor);
-    params.yarn_beta_fast = json_value(jparams, "yarn_beta_fast", default_params.yarn_beta_fast);
-    params.yarn_beta_slow = json_value(jparams, "yarn_beta_slow", default_params.yarn_beta_slow);
-    params.yarn_orig_ctx = json_value(jparams, "yarn_orig_ctx", default_params.yarn_orig_ctx);
-    params.defrag_thold = json_value(jparams, "defrag_thold", default_params.defrag_thold);
-    params.numa = json_value(jparams, "numa", default_params.numa);
-    params.rope_scaling_type = json_value(jparams, "rope_scaling_type", default_params.rope_scaling_type);
-    params.pooling_type = json_value(jparams, "pooling_type", default_params.pooling_type);
-    params.model = json_value(jparams, "model", default_params.model);
-    params.speculative.model = json_value(jparams, "model_draft", default_params.speculative.model);
-    params.model_alias = json_value(jparams, "model_alias", default_params.model_alias);
-    params.model_url = json_value(jparams, "model_url", default_params.model_url);
-    params.hf_repo = json_value(jparams, "hf_repo", default_params.hf_repo);
-    params.hf_file = json_value(jparams, "hf_file", default_params.hf_file);
-    params.prompt = json_value(jparams, "prompt", default_params.prompt);
-    params.prompt_file = json_value(jparams, "prompt_file", default_params.prompt_file);
-    params.path_prompt_cache = json_value(jparams, "path_prompt_cache", default_params.path_prompt_cache);
-    params.input_prefix = json_value(jparams, "input_prefix", default_params.input_prefix);
-    params.input_suffix = json_value(jparams, "input_suffix", default_params.input_suffix);
-    params.antiprompt = json_value(jparams, "antiprompt", default_params.antiprompt);
-    params.lookup_cache_static = json_value(jparams, "lookup_cache_static", default_params.lookup_cache_static);
-    params.lookup_cache_dynamic = json_value(jparams, "lookup_cache_dynamic", default_params.lookup_cache_dynamic);
-    params.logits_file = json_value(jparams, "logits_file", default_params.logits_file);
-    // params.lora_adapters = json_value(jparams, "lora_adapter", default_params.lora_adapters);
-    params.embedding = json_value(jparams, "embedding", default_params.embedding);
-    params.escape = json_value(jparams, "escape", default_params.escape);
-    params.cont_batching = json_value(jparams, "cont_batching", default_params.cont_batching);
-    params.flash_attn = json_value(jparams, "flash_attn", default_params.flash_attn);
-    params.input_prefix_bos = json_value(jparams, "input_prefix_bos", default_params.input_prefix_bos);
-    params.sampling.ignore_eos = json_value(jparams, "ignore_eos", default_params.sampling.ignore_eos);
-    params.use_mmap = json_value(jparams, "use_mmap", default_params.use_mmap);
-    params.use_mlock = json_value(jparams, "use_mlock", default_params.use_mlock);
-    params.no_kv_offload = json_value(jparams, "no_kv_offload", default_params.no_kv_offload);
-    params.chat_template = json_value(jparams, "chat_template", default_params.chat_template);
-
-    if (jparams.contains("n_gpu_layers")) {
-        if (llama_supports_gpu_offload()) {
-            params.n_gpu_layers = json_value(jparams, "n_gpu_layers", default_params.n_gpu_layers);
-            params.speculative.n_gpu_layers =
-                json_value(jparams, "n_gpu_layers_draft", default_params.speculative.n_gpu_layers);
-        } else {
-            SRV_WRN("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. "
-                    "See main README.md for information on enabling GPU BLAS support: %s = %d",
-                    "n_gpu_layers", params.n_gpu_layers);
-        }
-    }
-
-    if (jparams.contains("split_mode")) {
-        params.split_mode = json_value(jparams, "split_mode", default_params.split_mode);
-// todo: the definition checks here currently don't work due to cmake visibility reasons
-#ifndef GGML_USE_CUDA
-        fprintf(stderr, "warning: llama.cpp was compiled without CUDA. Setting the split mode has no effect.\n");
-#endif
-    }
-
-    if (jparams.contains("tensor_split")) {
-#if defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)
-        std::vector<float> tensor_split = jparams["tensor_split"].get<std::vector<float>>();
-        GGML_ASSERT(tensor_split.size() <= llama_max_devices());
-
-        for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device) {
-            if (i_device < tensor_split.size()) {
-                params.tensor_split[i_device] = tensor_split.at(i_device);
-            } else {
-                params.tensor_split[i_device] = 0.0f;
-            }
-        }
-#else
-        SRV_WRN("%s", "llama.cpp was compiled without CUDA. It is not possible to set a tensor split.\n");
-#endif // GGML_USE_CUDA
-    }
-
-    if (jparams.contains("main_gpu")) {
-#if defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)
-        params.main_gpu = json_value(jparams, "main_gpu", default_params.main_gpu);
-#else
-        SRV_WRN("%s", "llama.cpp was compiled without CUDA. It is not possible to set a main GPU.");
-#endif
-    }
-
-    common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token);
-}
diff --git a/src/main/java/de/kherud/llama/InferenceParameters.java b/src/main/java/de/kherud/llama/InferenceParameters.java
index 0ac1b1dc..41f74cc9 100644
--- a/src/main/java/de/kherud/llama/InferenceParameters.java
+++ b/src/main/java/de/kherud/llama/InferenceParameters.java
@@ -1,6 +1,7 @@
 package de.kherud.llama;

 import java.util.Collection;
+import java.util.List;
 import java.util.Map;

 import de.kherud.llama.args.MiroStat;
@@ -11,6 +12,7 @@
  * and
  * {@link LlamaModel#complete(InferenceParameters)}.
  */
+@SuppressWarnings("unused")
 public final class InferenceParameters extends JsonParameters {

 	private static final String PARAM_PROMPT = "prompt";
@@ -47,6 +49,7 @@ public final class InferenceParameters extends JsonParameters {
 	private static final String PARAM_STREAM = "stream";
 	private static final String PARAM_USE_CHAT_TEMPLATE = "use_chat_template";
 	private static final String PARAM_USE_JINJA = "use_jinja";
+	private static final String PARAM_MESSAGES = "messages";

 	public InferenceParameters(String prompt) {
 		// we always need a prompt
@@ -480,21 +483,64 @@ public InferenceParameters setSamplers(Sampler... samplers) {
 		return this;
 	}

-	InferenceParameters setStream(boolean stream) {
-		parameters.put(PARAM_STREAM, String.valueOf(stream));
-		return this;
-	}
-
 	/**
-	 * Set whether or not generate should apply a chat template (default: false)
+	 * Set whether generate should apply a chat template (default: false)
 	 */
 	public InferenceParameters setUseChatTemplate(boolean useChatTemplate) {
 		parameters.put(PARAM_USE_JINJA, String.valueOf(useChatTemplate));
 		return this;
 	}
-
-
-
+	/**
+	 * Set the messages for chat-based inference.
+	 * - Allows **only one** system message.
+	 * - Allows **one or more** user/assistant messages.
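+	 * - Example: {@code setMessages("You are a librarian.", List.of(new Pair<>("user", "What is the best book?")))}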
+	 */
+	public InferenceParameters setMessages(String systemMessage, List<Pair<String, String>> messages) {
+		StringBuilder messagesBuilder = new StringBuilder();
+		messagesBuilder.append("[");
+
+		// Add system message (if provided)
+		if (systemMessage != null && !systemMessage.isEmpty()) {
+			messagesBuilder.append("{\"role\": \"system\", \"content\": ")
+					.append(toJsonString(systemMessage))
+					.append("}");
+			if (!messages.isEmpty()) {
+				messagesBuilder.append(", ");
+			}
+		}
+
+		// Add user/assistant messages
+		for (int i = 0; i < messages.size(); i++) {
+			Pair<String, String> message = messages.get(i);
+			String role = message.getKey();
+			String content = message.getValue();
+
+			if (!role.equals("user") && !role.equals("assistant")) {
+				throw new IllegalArgumentException("Invalid role: " + role + ". Role must be 'user' or 'assistant'.");
+			}
+
+			messagesBuilder.append("{\"role\": ")
+					.append(toJsonString(role))
+					.append(", \"content\": ")
+					.append(toJsonString(content))
+					.append("}");
+
+			if (i < messages.size() - 1) {
+				messagesBuilder.append(", ");
+			}
+		}
+
+		messagesBuilder.append("]");
+
+		// Serialize the built messages to a JSON string and store it in parameters
+		parameters.put(PARAM_MESSAGES, messagesBuilder.toString());
+		return this;
+	}
+
+	InferenceParameters setStream(boolean stream) {
+		parameters.put(PARAM_STREAM, String.valueOf(stream));
+		return this;
+	}
 }
diff --git a/src/main/java/de/kherud/llama/LlamaIterator.java b/src/main/java/de/kherud/llama/LlamaIterator.java
index fdff993b..cb1c5c2c 100644
--- a/src/main/java/de/kherud/llama/LlamaIterator.java
+++ b/src/main/java/de/kherud/llama/LlamaIterator.java
@@ -35,6 +35,9 @@ public LlamaOutput next() {
 		}
 		LlamaOutput output = model.receiveCompletion(taskId);
 		hasNext = !output.stop;
+		if (output.stop) {
+			model.releaseTask(taskId);
+		}
 		return output;
 	}
diff --git a/src/main/java/de/kherud/llama/LlamaModel.java b/src/main/java/de/kherud/llama/LlamaModel.java
index 7749b321..eab36202 100644
--- a/src/main/java/de/kherud/llama/LlamaModel.java
+++ b/src/main/java/de/kherud/llama/LlamaModel.java
@@ -5,6 +5,9 @@
 import java.lang.annotation.Native;
 import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
 import java.util.function.BiConsumer;

 /**
@@ -130,11 +133,39 @@ public void close() {

 	private native void delete();

-	private native void releaseTask(int taskId);
+	native void releaseTask(int taskId);

 	private static native byte[] jsonSchemaToGrammarBytes(String schema);

 	public static String jsonSchemaToGrammar(String schema) {
 		return new String(jsonSchemaToGrammarBytes(schema), StandardCharsets.UTF_8);
 	}
+
+	public List<Pair<String, Float>> rerank(boolean reRank, String query, String... documents) {
+		LlamaOutput output = rerank(query, documents);
+
+		Map<String, Float> scoredDocumentMap = output.probabilities;
+
+		List<Pair<String, Float>> rankedDocuments = new ArrayList<>();
+
+		if (reRank) {
+			// Sort in descending order based on Float values
+			scoredDocumentMap.entrySet()
+					.stream()
+					.sorted((a, b) -> Float.compare(b.getValue(), a.getValue())) // Descending order
+					.forEach(entry -> rankedDocuments.add(new Pair<>(entry.getKey(), entry.getValue())));
+		} else {
+			// Copy without sorting
+			scoredDocumentMap.forEach((key, value) -> rankedDocuments.add(new Pair<>(key, value)));
+		}
+
+		return rankedDocuments;
+	}
+
+	public native LlamaOutput rerank(String query, String... documents);
+
+	public String applyTemplate(InferenceParameters parameters) {
+		return applyTemplate(parameters.toString());
+	}
+
+	public native String applyTemplate(String parametersJson);
 }
diff --git a/src/main/java/de/kherud/llama/ModelParameters.java b/src/main/java/de/kherud/llama/ModelParameters.java
index 8615bd50..7999295d 100644
--- a/src/main/java/de/kherud/llama/ModelParameters.java
+++ b/src/main/java/de/kherud/llama/ModelParameters.java
@@ -459,7 +459,7 @@ public ModelParameters setJsonSchema(String schema) {
 	 * Set pooling type for embeddings (default: model default if unspecified).
 	 */
 	public ModelParameters setPoolingType(PoolingType type) {
-		parameters.put("--pooling", String.valueOf(type.getId()));
+		parameters.put("--pooling", type.getArgValue());
 		return this;
 	}

@@ -467,7 +467,7 @@ public ModelParameters setPoolingType(PoolingType type) {
 	 * Set RoPE frequency scaling method (default: linear unless specified by the model).
 	 */
 	public ModelParameters setRopeScaling(RopeScalingType type) {
-		parameters.put("--rope-scaling", String.valueOf(type.getId()));
+		parameters.put("--rope-scaling", type.getArgValue());
 		return this;
 	}

@@ -584,7 +584,7 @@ public ModelParameters setCacheTypeV(CacheType type) {
 	}

 	/**
-	 * Set KV cache defragmentation threshold (default: 0.1, < 0 - disabled).
+	 * Set KV cache defragmentation threshold (default: 0.1, &lt; 0 - disabled).
 	 */
 	public ModelParameters setDefragThold(float defragThold) {
 		parameters.put("--defrag-thold", String.valueOf(defragThold));
@@ -640,7 +640,7 @@ public ModelParameters setNuma(NumaStrategy numaStrategy) {
 	}

 	/**
-	 * Set comma-separated list of devices to use for offloading (none = don't offload).
+	 * Set comma-separated list of devices to use for offloading &lt;dev1,dev2,..&gt; (none = don't offload).
 	 */
 	public ModelParameters setDevices(String devices) {
 		parameters.put("--device", devices);
 		return this;
 	}
@@ -960,3 +960,5 @@ public ModelParameters enableJinja() {
 	}

 }
+
+
diff --git a/src/main/java/de/kherud/llama/OSInfo.java b/src/main/java/de/kherud/llama/OSInfo.java
index 772aeaef..9354ec2f 100644
--- a/src/main/java/de/kherud/llama/OSInfo.java
+++ b/src/main/java/de/kherud/llama/OSInfo.java
@@ -200,7 +200,7 @@ else if (armType.startsWith("aarch64")) {
 			}

 			// Java 1.8 introduces a system property to determine armel or armhf
-			// http://bugs.java.com/bugdatabase/view_bug.do?bug_id=8005545
+			// https://bugs.openjdk.org/browse/JDK-8005545
 			String abi = System.getProperty("sun.arch.abi");
 			if (abi != null && abi.startsWith("gnueabihf")) {
 				return "armv7";
diff --git a/src/main/java/de/kherud/llama/Pair.java b/src/main/java/de/kherud/llama/Pair.java
new file mode 100644
index 00000000..48ac648b
--- /dev/null
+++ b/src/main/java/de/kherud/llama/Pair.java
@@ -0,0 +1,48 @@
+package de.kherud.llama;
+
+import java.util.Objects;
+
+public class Pair<K, V> {
+
+	private final K key;
+	private final V value;
+
+	public Pair(K key, V value) {
+		this.key = key;
+		this.value = value;
+	}
+
+	public K getKey() {
+		return key;
+	}
+
+	public V getValue() {
+		return value;
+	}
+
+	@Override
+	public int hashCode() {
+		return Objects.hash(key, value);
+	}
+
+	@Override
+	public boolean equals(Object obj) {
+		if (this == obj)
+			return true;
+		if (obj == null)
+			return false;
+		if (getClass() != obj.getClass())
+			return false;
+		Pair<?, ?> other = (Pair<?, ?>) obj;
+		return Objects.equals(key, other.key) && Objects.equals(value, other.value);
+	}
+
+	@Override
+	public String toString() {
+		return "Pair [key=" + key + ", value=" + value + "]";
+	}
+
+}
diff --git a/src/main/java/de/kherud/llama/args/PoolingType.java b/src/main/java/de/kherud/llama/args/PoolingType.java
index a9c9dbae..c0379c85 100644
--- a/src/main/java/de/kherud/llama/args/PoolingType.java
+++ b/src/main/java/de/kherud/llama/args/PoolingType.java
@@ -2,20 +2,20 @@
 public enum PoolingType {

-    UNSPECIFIED(-1),
-    NONE(0),
-    MEAN(1),
-    CLS(2),
-    LAST(3),
-    RANK(4);
+    UNSPECIFIED("unspecified"),
+    NONE("none"),
+    MEAN("mean"),
+    CLS("cls"),
+    LAST("last"),
+    RANK("rank");

-    private final int id;
+    private final String argValue;

-    PoolingType(int value) {
-        this.id = value;
+    PoolingType(String value) {
+        this.argValue = value;
     }

-    public int getId() {
-        return id;
+    public String getArgValue() {
+        return argValue;
     }
-}
+}
\ No newline at end of file
diff --git a/src/main/java/de/kherud/llama/args/RopeScalingType.java b/src/main/java/de/kherud/llama/args/RopeScalingType.java
index eed939a1..138d05be 100644
--- a/src/main/java/de/kherud/llama/args/RopeScalingType.java
+++ b/src/main/java/de/kherud/llama/args/RopeScalingType.java
@@ -2,20 +2,20 @@
 public enum RopeScalingType {

-    UNSPECIFIED(-1),
-    NONE(0),
-    LINEAR(1),
-    YARN2(2),
-    LONGROPE(3),
-    MAX_VALUE(3);
+    UNSPECIFIED("unspecified"),
+    NONE("none"),
+    LINEAR("linear"),
+    YARN2("yarn"),
+    LONGROPE("longrope"),
+    MAX_VALUE("maxvalue");

-    private final int id;
+    private final String argValue;

-    RopeScalingType(int value) {
-        this.id = value;
+    RopeScalingType(String value) {
+        this.argValue = value;
     }

-    public int getId() {
-        return id;
+    public String getArgValue() {
+        return argValue;
     }
-}
+}
\ No newline at end of file
diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java
index f2e931b4..e3e69d8c 100644
--- a/src/test/java/de/kherud/llama/LlamaModelTest.java
+++ b/src/test/java/de/kherud/llama/LlamaModelTest.java
@@ -158,6 +158,26 @@ public void testEmbedding() {
 		float[] embedding = model.embed(prefix);
 		Assert.assertEquals(4096, embedding.length);
 	}
+
+	/**
+	 * To run this test, download the model from https://huggingface.co/mradermacher/jina-reranker-v1-tiny-en-GGUF/tree/main,
+	 * replace .enableEmbedding() with .enableReranking() in the model setup, and then enable the test.
+	 */
+	@Ignore
+	public void testReRanking() {
+
+		String query = "Machine learning is";
+		String[] TEST_DOCUMENTS = new String[] {
+				"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
+				"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
+				"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
+				"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
+		};
+		LlamaOutput llamaOutput = model.rerank(query, TEST_DOCUMENTS[0], TEST_DOCUMENTS[1], TEST_DOCUMENTS[2], TEST_DOCUMENTS[3]);
+
+		System.out.println(llamaOutput);
+	}

 	@Test
 	public void testTokenization() {
@@ -296,4 +316,20 @@ public void testJsonSchemaToGrammar() {
 		String actualGrammar = LlamaModel.jsonSchemaToGrammar(schema);
 		Assert.assertEquals(expectedGrammar, actualGrammar);
 	}
+
+	@Test
+	public void testTemplate() {
+
+		List<Pair<String, String>> userMessages = new ArrayList<>();
+		userMessages.add(new Pair<>("user", "What is the best book?"));
+		userMessages.add(new Pair<>("assistant", "It depends on your interests. Do you like fiction or non-fiction?"));
+
+		InferenceParameters params = new InferenceParameters("A book recommendation system.")
+				.setMessages("Book", userMessages)
+				.setTemperature(0.95f)
+				.setStopStrings("\"\"\"")
+				.setNPredict(nPredict)
+				.setSeed(42);
+		Assert.assertEquals(model.applyTemplate(params), "<|im_start|>system\nBook<|im_end|>\n<|im_start|>user\nWhat is the best book?<|im_end|>\n<|im_start|>assistant\nIt depends on your interests. Do you like fiction or non-fiction?<|im_end|>\n<|im_start|>assistant\n");
+	}
 }
diff --git a/src/test/java/de/kherud/llama/RerankingModelTest.java b/src/test/java/de/kherud/llama/RerankingModelTest.java
new file mode 100644
index 00000000..60d32bde
--- /dev/null
+++ b/src/test/java/de/kherud/llama/RerankingModelTest.java
@@ -0,0 +1,83 @@
+package de.kherud.llama;
+
+import java.util.List;
+import java.util.Map;
+
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class RerankingModelTest {
+
+	private static LlamaModel model;
+
+	String query = "Machine learning is";
+	String[] TEST_DOCUMENTS = new String[] {
+			"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
+			"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
+			"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
+			"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine." };
+
+	@BeforeClass
+	public static void setup() {
+		model = new LlamaModel(
+				new ModelParameters().setCtxSize(128).setModel("models/jina-reranker-v1-tiny-en-Q4_0.gguf")
+						.setGpuLayers(43).enableReranking().enableLogTimestamps().enableLogPrefix());
+	}
+
+	@AfterClass
+	public static void tearDown() {
+		if (model != null) {
+			model.close();
+		}
+	}
+
+	@Test
+	public void testReRanking() {
+
+		LlamaOutput llamaOutput = model.rerank(query, TEST_DOCUMENTS[0], TEST_DOCUMENTS[1], TEST_DOCUMENTS[2],
+				TEST_DOCUMENTS[3]);
+
+		Map<String, Float> rankedDocumentsMap = llamaOutput.probabilities;
+		Assert.assertEquals(TEST_DOCUMENTS.length, rankedDocumentsMap.size());
+
+		// Finding the most and least relevant documents
+		String mostRelevantDoc = null;
+		String leastRelevantDoc = null;
+		float maxScore = -Float.MAX_VALUE; // rerank scores can be negative, so don't start at Float.MIN_VALUE
+		float minScore = Float.MAX_VALUE;
+
+		for (Map.Entry<String, Float> entry : rankedDocumentsMap.entrySet()) {
+			if (entry.getValue() > maxScore) {
+				maxScore = entry.getValue();
+				mostRelevantDoc = entry.getKey();
+			}
+			if (entry.getValue() < minScore) {
+				minScore = entry.getValue();
+				leastRelevantDoc = entry.getKey();
+			}
+		}
+
+		// Assertions
+		Assert.assertTrue(maxScore > minScore);
+		Assert.assertEquals("Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", mostRelevantDoc);
+		Assert.assertEquals("Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine.", leastRelevantDoc);
+	}
+
+	@Test
+	public void testSortedReRanking() {
+		List<Pair<String, Float>> rankedDocuments = model.rerank(true, query, TEST_DOCUMENTS);
+		Assert.assertEquals(rankedDocuments.size(), TEST_DOCUMENTS.length);
+
+		// Check the ranking order: each score should be >= the next one
+		for (int i = 0; i < rankedDocuments.size() - 1; i++) {
+			float currentScore = rankedDocuments.get(i).getValue();
+			float nextScore = rankedDocuments.get(i + 1).getValue();
+			Assert.assertTrue("Ranking order incorrect at index " + i, currentScore >= nextScore);
+		}
+	}
+}
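
As a quick reference for trying the change locally, here is a minimal usage sketch of the reranking API introduced above. The model path, query, and documents are illustrative only; the calls mirror `RerankingModelTest`:

```java
import java.util.List;

import de.kherud.llama.LlamaModel;
import de.kherud.llama.ModelParameters;
import de.kherud.llama.Pair;

public class RerankExample {

    public static void main(String... args) {
        // Reranking mode must be enabled explicitly; it is mutually exclusive with embedding mode.
        ModelParameters params = new ModelParameters()
                .setModel("models/jina-reranker-v1-tiny-en-Q4_0.gguf")
                .enableReranking();
        try (LlamaModel model = new LlamaModel(params)) {
            // rerank(true, ...) returns the documents sorted by descending relevance score.
            List<Pair<String, Float>> ranked = model.rerank(true, "Machine learning is",
                    "A machine is a physical system that uses power to perform an action.",
                    "Machine learning is a field of study in artificial intelligence.");
            for (Pair<String, Float> document : ranked) {
                System.out.println(document.getValue() + ": " + document.getKey());
            }
        }
    }
}
```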