diff --git a/Package.swift b/Package.swift index b285f1b..032458a 100644 --- a/Package.swift +++ b/Package.swift @@ -27,8 +27,10 @@ let package = Package( .library(name: "SpeziLLMFog", targets: ["SpeziLLMFog"]) ], dependencies: [ + .package(url: "https://github.com/ml-explore/mlx-swift", from: "0.18.1"), + .package(url: "https://github.com/ml-explore/mlx-swift-examples", from: "1.16.0"), + .package(url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "0.1.12")), .package(url: "https://github.com/StanfordBDHG/OpenAI", .upToNextMinor(from: "0.2.9")), - .package(url: "https://github.com/StanfordBDHG/llama.cpp", .upToNextMinor(from: "0.3.3")), .package(url: "https://github.com/StanfordSpezi/Spezi", from: "1.2.1"), .package(url: "https://github.com/StanfordSpezi/SpeziFoundation", from: "2.0.0-beta.3"), .package(url: "https://github.com/StanfordSpezi/SpeziStorage", from: "1.0.2"), @@ -49,19 +51,24 @@ let package = Package( name: "SpeziLLMLocal", dependencies: [ .target(name: "SpeziLLM"), - .product(name: "llama", package: "llama.cpp"), .product(name: "SpeziFoundation", package: "SpeziFoundation"), - .product(name: "Spezi", package: "Spezi") - ], - swiftSettings: [ - .interoperabilityMode(.Cxx) + .product(name: "Spezi", package: "Spezi"), + .product(name: "MLX", package: "mlx-swift"), + .product(name: "MLXFast", package: "mlx-swift"), + .product(name: "MLXNN", package: "mlx-swift"), + .product(name: "MLXOptimizers", package: "mlx-swift"), + .product(name: "MLXRandom", package: "mlx-swift"), + .product(name: "Transformers", package: "swift-transformers"), + .product(name: "LLM", package: "mlx-swift-examples") ] ), .target( name: "SpeziLLMLocalDownload", dependencies: [ .product(name: "SpeziOnboarding", package: "SpeziOnboarding"), - .product(name: "SpeziViews", package: "SpeziViews") + .product(name: "SpeziViews", package: "SpeziViews"), + .target(name: "SpeziLLMLocal"), + .product(name: "LLM", package: "mlx-swift-examples") ] ), .target( diff --git a/README.md b/README.md index f173a3a..2de7818 100644 --- a/README.md +++ b/README.md @@ -57,37 +57,13 @@ The section below highlights the setup and basic use of the [SpeziLLMLocal](http ### Spezi LLM Local -The target enables developers to easily execute medium-size Language Models (LLMs) locally on-device via the [llama.cpp framework](https://github.com/ggerganov/llama.cpp). The module allows you to interact with the locally run LLM via purely Swift-based APIs, no interaction with low-level C or C++ code is necessary, building on top of the infrastructure of the [SpeziLLM target](https://swiftpackageindex.com/stanfordspezi/spezillm/documentation/spezillm). +The target enables developers to easily execute medium-size Language Models (LLMs) locally on-device. The module allows you to interact with the locally run LLM via purely Swift-based APIs, no interaction with low-level code is necessary, building on top of the infrastructure of the [SpeziLLM target](https://swiftpackageindex.com/stanfordspezi/spezillm/documentation/spezillm). + +> [!IMPORTANT] +> Spezi LLM Local is not compatible with simulators. The underlying [`mlx-swift`](https://github.com/ml-explore/mlx-swift) requires a modern Metal MTLGPUFamily and the simulator does not provide that. 
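With llama.cpp gone, the Swift/C++ interop build setting described in the removed README section that follows is no longer required in consuming projects. A minimal sketch of a consuming `Package.swift` under the new setup — the package name, platform versions, and the SpeziLLM version constraint are illustrative, not taken from this PR:

```swift
// swift-tools-version:5.9
import PackageDescription

let package = Package(
    name: "ExampleConsumingPackage",
    platforms: [.iOS(.v17), .macOS(.v14)],
    dependencies: [
        // Version constraint is illustrative; pin to the SpeziLLM release you actually use.
        .package(url: "https://github.com/StanfordSpezi/SpeziLLM", .upToNextMinor(from: "0.6.0"))
    ],
    targets: [
        .target(
            name: "ExampleConsumingTarget",
            dependencies: [
                // No `swiftSettings: [.interoperabilityMode(.Cxx)]` is needed anymore,
                // since the MLX-based backend is pure Swift from the consumer's perspective.
                .product(name: "SpeziLLMLocal", package: "SpeziLLM")
            ]
        )
    ]
)
```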
> [!IMPORTANT] -> Important: In order to use the LLM local target, one needs to set build parameters in the consuming Xcode project or the consuming SPM package to enable the [Swift / C++ Interop](https://www.swift.org/documentation/cxx-interop/), introduced in Xcode 15 and Swift 5.9. Keep in mind that this is true for nested dependencies, one needs to set this configuration recursivly for the entire dependency tree towards the llama.cpp SPM package. -> -> **For Xcode projects:** -> - Open your [build settings in Xcode](https://developer.apple.com/documentation/xcode/configuring-the-build-settings-of-a-target/) by selecting *PROJECT_NAME > TARGET_NAME > Build Settings*. -> - Within the *Build Settings*, search for the `C++ and Objective-C Interoperability` setting and set it to `C++ / Objective-C++`. This enables the project to use the C++ headers from llama.cpp. -> -> **For SPM packages:** -> - Open the `Package.swift` file of your [SPM package]((https://www.swift.org/documentation/package-manager/)) -> - Within the package `target` that consumes the llama.cpp package, add the `interoperabilityMode(_:)` Swift build setting like that: -> ```swift -> /// Adds the dependency to the Spezi LLM SPM package -> dependencies: [ -> .package(url: "https://github.com/StanfordSpezi/SpeziLLM", .upToNextMinor(from: "0.6.0")) -> ], -> targets: [ -> .target( -> name: "ExampleConsumingTarget", -> /// State the dependence of the target to SpeziLLMLocal -> dependencies: [ -> .product(name: "SpeziLLMLocal", package: "SpeziLLM") -> ], -> /// Important: Configure the `.interoperabilityMode(_:)` within the `swiftSettings` -> swiftSettings: [ -> .interoperabilityMode(.Cxx) -> ] -> ) -> ] -> ``` +> Important: To use the LLM local target, some LLMs require adding the [Increase Memory Limit](https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_developer_kernel_increased-memory-limit) entitlement to the project. #### Setup @@ -123,7 +99,8 @@ struct LLMLocalDemoView: View { // Instantiate the `LLMLocalSchema` to an `LLMLocalSession` via the `LLMRunner`. let llmSession: LLMLocalSession = runner( with: LLMLocalSchema( - modelPath: URL(string: "URL to the local model file")! + model: .llama3_8B_4bit, + formatChat: LLMLocalSchema.PromptFormattingDefaults.llama3 ) ) diff --git a/Sources/SpeziLLMLocal/Configuration/LLMLocalContextParameters.swift b/Sources/SpeziLLMLocal/Configuration/LLMLocalContextParameters.swift index e4ad4a9..a707e83 100644 --- a/Sources/SpeziLLMLocal/Configuration/LLMLocalContextParameters.swift +++ b/Sources/SpeziLLMLocal/Configuration/LLMLocalContextParameters.swift @@ -1,218 +1,27 @@ // // This source file is part of the Stanford Spezi open source project // -// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md) +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) // // SPDX-License-Identifier: MIT // import Foundation -@preconcurrency import llama /// Represents the context parameters of the LLM. -/// -/// Internally, these data points are passed as a llama.cpp `llama_context_params` C struct to the LLM. public struct LLMLocalContextParameters: Sendable { - // swiftlint:disable identifier_name - /// Swift representation of the `ggml_type` of llama.cpp, indicating data types within KV caches. 
- public enum GGMLType: UInt32 { - case f32 = 0 - case f16 - case q4_0 - case q4_1 - case q5_0 = 6 - case q5_1 - case q8_0 - case q8_1 - /// k-quantizations - case q2_k - case q3_k - case q4_k - case q5_k - case q6_k - case q8_k - case iq2_xxs - case iq2_xs - case i8 - case i16 - case i32 - } - // swiftlint:enable identifier_name - - - /// Wrapped C struct from the llama.cpp library, later-on passed to the LLM - private var wrapped: llama_context_params - - - /// Context parameters in llama.cpp's low-level C representation - var llamaCppRepresentation: llama_context_params { - wrapped - } - /// RNG seed of the LLM - var seed: UInt32 { - get { - wrapped.seed - } - set { - wrapped.seed = newValue - } - } - - /// Context window size in tokens (0 = take default window size from model) - var contextWindowSize: UInt32 { - get { - wrapped.n_ctx - } - set { - wrapped.n_ctx = newValue - } - } - - /// Maximum batch size during prompt processing - var batchSize: UInt32 { - get { - wrapped.n_batch - } - set { - wrapped.n_batch = newValue - } - } - - /// Number of threads used by LLM for generation of output - var threadCount: UInt32 { - get { - wrapped.n_threads - } - set { - wrapped.n_threads = newValue - } - } - - /// Number of threads used by LLM for batch processing - var threadCountBatch: UInt32 { - get { - wrapped.n_threads_batch - } - set { - wrapped.n_threads_batch = newValue - } - } - - /// RoPE base frequency (0 = take default from model) - var ropeFreqBase: Float { - get { - wrapped.rope_freq_base - } - set { - wrapped.rope_freq_base = newValue - } - } - - /// RoPE frequency scaling factor (0 = take default from model) - var ropeFreqScale: Float { - get { - wrapped.rope_freq_scale - } - set { - wrapped.rope_freq_scale = newValue - } - } - - /// If `true`, offload the KQV ops (including the KV cache) to GPU - var offloadKQV: Bool { - get { - wrapped.offload_kqv - } - set { - wrapped.offload_kqv = newValue - } - } - - /// ``GGMLType`` of the key of the KV cache - var kvKeyType: GGMLType { - get { - GGMLType(rawValue: wrapped.type_k.rawValue) ?? .f16 - } - set { - wrapped.type_k = ggml_type(rawValue: newValue.rawValue) - } - } - - /// ``GGMLType`` of the value of the KV cache - var kvValueType: GGMLType { - get { - GGMLType(rawValue: wrapped.type_v.rawValue) ?? .f16 - } - set { - wrapped.type_v = ggml_type(rawValue: newValue.rawValue) - } - } - - /// If `true`, the (deprecated) `llama_eval()` call computes all logits, not just the last one - var computeAllLogits: Bool { - get { - wrapped.logits_all - } - set { - wrapped.logits_all = newValue - } - } - - /// If `true`, the mode is set to embeddings only - var embeddingsOnly: Bool { - get { - wrapped.embeddings - } - set { - wrapped.embeddings = newValue - } - } + var seed: UInt64? /// Creates the ``LLMLocalContextParameters`` which wrap the underlying llama.cpp `llama_context_params` C struct. /// Is passed to the underlying llama.cpp model in order to configure the context of the LLM. /// /// - Parameters: - /// - seed: RNG seed of the LLM, defaults to `4294967295` (which represents a random seed). - /// - contextWindowSize: Context window size in tokens, defaults to `1024`. - /// - batchSize: Maximum batch size during prompt processing, defaults to `1024` tokens. - /// - threadCount: Number of threads used by LLM for generation of output, defaults to the processor count of the device. - /// - threadCountBatch: Number of threads used by LLM for batch processing, defaults to the processor count of the device. 
- /// - ropeFreqBase: RoPE base frequency, defaults to `0` indicating the default from model. - /// - ropeFreqScale: RoPE frequency scaling factor, defaults to `0` indicating the default from model. - /// - offloadKQV: Offloads the KQV ops (including the KV cache) to GPU, defaults to `true`. - /// - kvKeyType: ``GGMLType`` of the key of the KV cache, defaults to ``GGMLType/f16``. - /// - kvValueType: ``GGMLType`` of the value of the KV cache, defaults to ``GGMLType/f16``. - /// - computeAllLogits: `llama_eval()` call computes all logits, not just the last one. Defaults to `false`. - /// - embeddingsOnly: Embedding-only mode, defaults to `false`. + /// - seed: RNG seed of the LLM, defaults to a random seed. public init( - seed: UInt32 = 4294967295, - contextWindowSize: UInt32 = 1024, - batchSize: UInt32 = 1024, - threadCount: UInt32 = .init(ProcessInfo.processInfo.processorCount), - threadCountBatch: UInt32 = .init(ProcessInfo.processInfo.processorCount), - ropeFreqBase: Float = 0.0, - ropeFreqScale: Float = 0.0, - offloadKQV: Bool = true, - kvKeyType: GGMLType = .f16, - kvValueType: GGMLType = .f16, - computeAllLogits: Bool = false, - embeddingsOnly: Bool = false + seed: UInt64? = nil ) { - self.wrapped = llama_context_default_params() - self.seed = seed - self.contextWindowSize = contextWindowSize - self.batchSize = batchSize - self.threadCount = threadCount - self.threadCountBatch = threadCountBatch - self.ropeFreqBase = ropeFreqBase - self.ropeFreqScale = ropeFreqScale - self.offloadKQV = offloadKQV - self.kvKeyType = kvKeyType - self.kvValueType = kvValueType - self.computeAllLogits = computeAllLogits - self.embeddingsOnly = embeddingsOnly } } diff --git a/Sources/SpeziLLMLocal/Configuration/LLMLocalModel.swift b/Sources/SpeziLLMLocal/Configuration/LLMLocalModel.swift new file mode 100644 index 0000000..0e6f9bc --- /dev/null +++ b/Sources/SpeziLLMLocal/Configuration/LLMLocalModel.swift @@ -0,0 +1,87 @@ +// +// This source file is part of the Stanford Spezi open source project +// +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) +// +// SPDX-License-Identifier: MIT +// + + +// swiftlint:disable identifier_name +/// Represents the available LLM models. 
+public enum LLMLocalModel { + /// Llama 3.1, 8 Billion Parameters, Instruct Mode, 4-bit Version + case llama3_1_8B_4bit + /// Llama 3, 8 Billion Parameters, Instruction-Tuned, 4-bit Version + case llama3_8B_4bit + /// Llama 3.2, 1 Billion Parameters, Instruction-Tuned, 4-bit Version + case llama3_2_1B_4bit + /// Llama 3.2, 3 Billion Parameters, Instruction-Tuned, 4-bit Version + case llama3_2_3B_4bit + /// Mistral Nemo, Instruction-Tuned, Model 2407, 4-bit Version + case mistralNeMo4bit + /// SmolLM, 135 Million Parameters, Instruction-Tuned, 4-bit Version + case smolLM_135M_4bit + /// Mistral, 7 Billion Parameters, Instruction-Tuned, Version 0.3, 4-bit Version + case mistral7B4bit + /// Code Llama, 13 Billion Parameters, Instruction-Tuned, Hugging Face Format, 4-bit, MLX Version + case codeLlama13b4bit + /// Phi 2, Hugging Face Format, 4-bit, MLX Version + case phi4bit + /// Phi 3 Mini, 4K Context Window, Instruction-Tuned, 4-bit Version, No Q-Embedding + case phi3_4bit + /// Phi 3.5 Mini, Instruction-Tuned, 4-bit Version + case phi3_5_4bit + /// Quantized Gemma, 2 Billion Parameters, Instruction-Tuned + case gemma2bQuantized + /// Gemma 2, 9 Billion Parameters, Instruction-Tuned, 4-bit Version + case gemma_2_9b_it_4bit + /// Gemma 2, 2 Billion Parameters, Instruction-Tuned, 4-bit Version + case gemma_2_2b_it_4bit + /// Qwen 1.5, 0.5 Billion Parameters, Chat-Tuned, 4-bit Version + case qwen205b4bit + /// OpenELM, 270 Million Parameters, Instruction-Tuned + case openelm270m4bit + /// Set the Huggingface ID of the model. e.g. "\/\" + case custom(id: String) + + /// The Huggingface ID for the model + public var hubID: String { + switch self { + case .llama3_1_8B_4bit: + return "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit" + case .llama3_8B_4bit: + return "mlx-community/Meta-Llama-3-8B-Instruct-4bit" + case .llama3_2_1B_4bit: + return "mlx-community/Llama-3.2-1B-Instruct-4bit" + case .llama3_2_3B_4bit: + return "mlx-community/Llama-3.2-3B-Instruct-4bit" + case .mistralNeMo4bit: + return "mlx-community/Mistral-Nemo-Instruct-2407-4bit" + case .smolLM_135M_4bit: + return "mlx-community/SmolLM-135M-Instruct-4bit" + case .mistral7B4bit: + return "mlx-community/Mistral-7B-Instruct-v0.3-4bit" + case .codeLlama13b4bit: + return "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX" + case .phi4bit: + return "mlx-community/phi-2-hf-4bit-mlx" + case .phi3_4bit: + return "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed" + case .phi3_5_4bit: + return "mlx-community/Phi-3.5-mini-instruct-4bit" + case .gemma2bQuantized: + return "mlx-community/quantized-gemma-2b-it" + case .gemma_2_9b_it_4bit: + return "mlx-community/gemma-2-9b-it-4bit" + case .gemma_2_2b_it_4bit: + return "mlx-community/gemma-2-2b-it-4bit" + case .qwen205b4bit: + return "mlx-community/Qwen1.5-0.5B-Chat-4bit" + case .openelm270m4bit: + return "mlx-community/OpenELM-270M-Instruct" + case .custom(let id): + return id + } + } +} diff --git a/Sources/SpeziLLMLocal/Configuration/LLMLocalParameters.swift b/Sources/SpeziLLMLocal/Configuration/LLMLocalParameters.swift index 2d5e8e5..01c9fcf 100644 --- a/Sources/SpeziLLMLocal/Configuration/LLMLocalParameters.swift +++ b/Sources/SpeziLLMLocal/Configuration/LLMLocalParameters.swift @@ -7,17 +7,9 @@ // import Foundation -@preconcurrency import llama - /// Represents the parameters of the LLM. -/// -/// Internally, these data points are passed as a llama.cpp `llama_model_params` C struct to the LLM. 
public struct LLMLocalParameters: Sendable { - /// Typealias for an internal llama.cpp progress callback function - public typealias LlamaProgressCallback = (@convention(c) (Float, UnsafeMutableRawPointer?) -> Bool) - - /// Defaults of possible LLMs parameter settings. public enum Defaults { /// Default system prompt for local LLMs. @@ -31,100 +23,10 @@ public struct LLMLocalParameters: Sendable { let systemPrompt: String? /// Indicates the maximum output length generated by the LLM. let maxOutputLength: Int - /// Indicates whether the BOS token is added by the LLM. If `nil`, the default from the model itself is taken. - let addBosToken: Bool - - - /// Wrapped C struct from the llama.cpp library, later-on passed to the LLM - private var wrapped: llama_model_params - - - /// Model parameters in llama.cpp's low-level C representation - var llamaCppRepresentation: llama_model_params { - wrapped - } - - /// Number of layers to store in VRAM - /// - Note: On iOS simulators, this property has to be set to 0 (which is automatically done by the library). - var gpuLayerCount: Int32 { - get { - wrapped.n_gpu_layers - } - set { - wrapped.n_gpu_layers = newValue - } - } - - /// Indicates the GPU that is used for scratch and small tensors. - var mainGpu: Int32 { - get { - wrapped.main_gpu - } - set { - wrapped.main_gpu = newValue - } - } - - /// Indicates how to split layers across multiple GPUs. - var tensorSplit: UnsafePointer? { - get { - wrapped.tensor_split - } - set { - wrapped.tensor_split = newValue - } - } - - /// Progress callback called with a progress value between 0 and 1 - var progressCallback: LlamaProgressCallback? { - get { - wrapped.progress_callback - } - set { - wrapped.progress_callback = newValue - } - } - - /// Context pointer that is passed to the progress callback - var progressCallbackUserData: UnsafeMutableRawPointer? { - get { - wrapped.progress_callback_user_data - } - set { - wrapped.progress_callback_user_data = newValue - } - } - - /// Indicates wether booleans should be kept together to avoid misalignment during copy-by-value. - var vocabOnly: Bool { - get { - wrapped.vocab_only - } - set { - wrapped.vocab_only = newValue - } - } - - /// Indicates if mmap should be used. - var useMmap: Bool { - get { - wrapped.use_mmap - } - set { - wrapped.use_mmap = newValue - } - } - - /// Forces the system to keep model in RAM. - var useMlock: Bool { - get { - wrapped.use_mlock - } - set { - wrapped.use_mlock = newValue - } - } + let extraEOSTokens: Set + /// Interval for displaying output after every N tokens generated. + let displayEveryNTokens: Int /// Creates the ``LLMLocalParameters`` which wrap the underlying llama.cpp `llama_model_params` C struct. /// Is passed to the underlying llama.cpp model in order to configure the LLM. @@ -132,46 +34,17 @@ public struct LLMLocalParameters: Sendable { /// - Parameters: /// - systemPrompt: The to-be-used system prompt of the LLM enabling fine-tuning of the LLMs behaviour. Defaults to the regular default chat-based LLM system prompt. /// - maxOutputLength: The maximum output length generated by the Spezi LLM, defaults to `512`. - /// - addBosToken: Indicates wether the BOS token is added by the Spezi LLM, defaults to `false`. - /// - gpuLayerCount: Number of layers to store in VRAM, defaults to `1`, meaning Apple's `Metal` framework is enabled. - /// - mainGpu: GPU that is used for scratch and small tensors, defaults to `0` representing the main GPU. 
- /// - tensorSplit: Split layers across multiple GPUs, defaults to `nil`, meaning no split. - /// - progressCallback: Progress callback called with a progress value between 0 and 1, defaults to `nil`. - /// - progressCallbackUserData: Context pointer that is passed to the progress callback, defaults to `nil`. - /// - vocabOnly: Indicates wether booleans should be kept together to avoid misalignment during copy-by-value., defaults to `false`. - /// - useMmap: Indicates if mmap should be used., defaults to `true`. - /// - useMlock: Forces the system to keep model in RAM, defaults to `false`. + /// - extraEOSTokens: Additional tokens to use for end of string + /// - displayEveryNTokens: Interval for displaying output after every N tokens generated, defaults to `4`. public init( systemPrompt: String? = Defaults.defaultSystemPrompt, maxOutputLength: Int = 512, - addBosToken: Bool = false, - gpuLayerCount: Int32 = 1, - mainGpu: Int32 = 0, - tensorSplit: UnsafePointer? = nil, - progressCallback: LlamaProgressCallback? = nil, - progressCallbackUserData: UnsafeMutableRawPointer? = nil, - vocabOnly: Bool = false, - useMmap: Bool = true, - useMlock: Bool = false + extraEOSTokens: Set = [], + displayEveryNTokens: Int = 4 ) { - self.wrapped = llama_model_default_params() - self.systemPrompt = systemPrompt self.maxOutputLength = maxOutputLength - self.addBosToken = addBosToken - - /// Overwrite `gpuLayerCount` in case of a simulator target environment - #if targetEnvironment(simulator) - self.gpuLayerCount = 0 // Disable Metal on simulator as crash otherwise - #else - self.gpuLayerCount = gpuLayerCount - #endif - self.mainGpu = mainGpu - self.tensorSplit = tensorSplit - self.progressCallback = progressCallback - self.progressCallbackUserData = progressCallbackUserData - self.vocabOnly = vocabOnly - self.useMmap = useMmap - self.useMlock = useMlock + self.extraEOSTokens = extraEOSTokens + self.displayEveryNTokens = displayEveryNTokens } } diff --git a/Sources/SpeziLLMLocal/Configuration/LLMLocalPlatformConfiguration.swift b/Sources/SpeziLLMLocal/Configuration/LLMLocalPlatformConfiguration.swift index 13d708d..8652381 100644 --- a/Sources/SpeziLLMLocal/Configuration/LLMLocalPlatformConfiguration.swift +++ b/Sources/SpeziLLMLocal/Configuration/LLMLocalPlatformConfiguration.swift @@ -7,43 +7,55 @@ // import Foundation -import llama - /// Represents the configuration of the Spezi ``LLMLocalPlatform``. public struct LLMLocalPlatformConfiguration: Sendable { - /// Wrapper around the `ggml_numa_strategy` type of llama.cpp, indicating the non-unified memory access configuration of the device. - public enum NonUniformMemoryAccess: UInt32, Sendable { - case disabled - case distributed - case isolated - case numaCtl - case mirror - case count + /// Represents the memory limit for the MLX GPU. + public struct MemoryLimit: Sendable { + /// The memory limit in MB. + let limit: Int + /// Calls to malloc will wait on scheduled tasks if the limit is exceeded. If + /// there are no more scheduled tasks an error will be raised if `relaxed` + /// is false or memory will be allocated (including the potential for + /// swap) if `relaxed` is true. 
+ /// + /// The memory limit defaults to 1.5 times the maximum recommended working set + /// size reported by the device ([recommendedMaxWorkingSetSize](https://developer.apple.com/documentation/metal/mtldevice/2369280-recommendedmaxworkingsetsize)) + let relaxed: Bool - var wrappedValue: ggml_numa_strategy { - .init(rawValue: self.rawValue) + /// Creates the `MemoryLimit` which configures the GPU used by MLX. + /// + /// - Parameters: + /// - limit: The memory limit in MB. + /// - relaxed: See `relaxed` in ``LLMLocalPlatformConfiguration/MemoryLimit``. + public init(limit: Int, relaxed: Bool = false) { + self.limit = limit + self.relaxed = relaxed } } - + /// The cache limit in MB, to disable set limit to `0`. + let cacheLimit: Int? + /// The memory limit for the GPU used by MLX. + let memoryLimit: MemoryLimit? /// The task priority of the initiated LLM inference tasks. let taskPriority: TaskPriority - /// Indicates the non-unified memory access configuration of the device. - let nonUniformMemoryAccess: NonUniformMemoryAccess /// Creates the ``LLMLocalPlatformConfiguration`` which configures the Spezi ``LLMLocalPlatform``. /// /// - Parameters: + /// - cacheLimit: The cache limit for the GPU used by MLX, defaults to `nil`. + /// - memoryLimit: The memory limit for the GPU used by MLX, defaults to `nil`. /// - taskPriority: The task priority of the initiated LLM inference tasks, defaults to `.userInitiated`. - /// - nonUniformMemoryAccess: Indicates if this is a device with non-unified memory access. public init( - taskPriority: TaskPriority = .userInitiated, - nonUniformMemoryAccess: NonUniformMemoryAccess = .disabled + cacheLimit: Int? = nil, + memoryLimit: MemoryLimit? = nil, + taskPriority: TaskPriority = .userInitiated ) { + self.cacheLimit = cacheLimit + self.memoryLimit = memoryLimit self.taskPriority = taskPriority - self.nonUniformMemoryAccess = nonUniformMemoryAccess } } diff --git a/Sources/SpeziLLMLocal/Configuration/LLMLocalSamplingParameters.swift b/Sources/SpeziLLMLocal/Configuration/LLMLocalSamplingParameters.swift index bd1f941..24b3474 100644 --- a/Sources/SpeziLLMLocal/Configuration/LLMLocalSamplingParameters.swift +++ b/Sources/SpeziLLMLocal/Configuration/LLMLocalSamplingParameters.swift @@ -7,351 +7,36 @@ // import Foundation -import llama /// Represents the sampling parameters of the LLM. -/// -/// Internally, these data points are passed as a llama.cpp `llama_sampling_params` C struct to the LLM. -public struct LLMLocalSamplingParameters: Sendable { // swiftlint:disable:this type_body_length - /// Helper enum for the Mirostat sampling method - public enum Mirostat { - init(rawValue: Int, targetEntropy: Float = 5.0, learningRate: Float = 0.1) { - switch rawValue { - case 0: - self = .disabled - case 1: - self = .v1(targetEntropy: targetEntropy, learningRate: learningRate) - case 2: - self = .v2(targetEntropy: targetEntropy, learningRate: learningRate) - default: - self = .disabled - } - } - - - case disabled - case v1(targetEntropy: Float, learningRate: Float) // swiftlint:disable:this identifier_name - case v2(targetEntropy: Float, learningRate: Float) // swiftlint:disable:this identifier_name - - - var rawValue: Int { - switch self { - case .disabled: - return 0 - case .v1: - return 1 - case .v2: - return 2 - } - } - } - - public struct ClassifierFreeGuidance { - let negativePrompt: String? - let scale: Float - - - public init(negativePrompt: String? 
= nil, scale: Float = 1.0) { - self.negativePrompt = negativePrompt - self.scale = scale - } - } - - - /// Wrapped C struct from the llama.cpp library, later-on passed to the LLM. - private var wrapped: llama_sampling_params - - - /// Sampling parameters in llama.cpp's low-level C representation. - var llamaCppRepresentation: llama_sampling_params { - wrapped - } - - var llamaCppSamplingContext: UnsafeMutablePointer? { - llama_sampling_init(wrapped) - } - - /// Number of previous tokens to remember. - var rememberTokens: Int32 { - get { - wrapped.n_prev - } - set { - wrapped.n_prev = newValue - } - } - - /// If greater than 0, output the probabilities of top n\_probs tokens. - var outputProbabilities: Int32 { - get { - wrapped.n_probs - } - set { - wrapped.n_probs = newValue - } - } - - /// Top-K Sampling: K most likely next words (<= 0 to use vocab size). - var topK: Int32 { - get { - wrapped.top_k - } - set { - wrapped.top_k = newValue - } - } - +public struct LLMLocalSamplingParameters: Sendable { /// Top-p Sampling: Smallest possible set of words whose cumulative probability exceeds the probability p (1.0 = disabled). - var topP: Float { - get { - wrapped.top_p - } - set { - wrapped.top_p = newValue - } - } - - /// Min-p Sampling (0.0 = disabled). - var minP: Float { - get { - wrapped.min_p - } - set { - wrapped.min_p = newValue - } - } - - /// Tail Free Sampling (1.0 = disabled). - var tfs: Float { - get { - wrapped.tfs_z - } - set { - wrapped.tfs_z = newValue - } - } - - /// Locally Typical Sampling. - var typicalP: Float { - get { - wrapped.typical_p - } - set { - wrapped.typical_p = newValue - } - } - + let topP: Float /// Temperature Sampling: A higher value indicates more creativity of the model but also more hallucinations. - var temperature: Float { - get { - wrapped.temp - } - set { - wrapped.temp = newValue - } - } - - /// Last n tokens to penalize (0 = disable penalty, -1 = context size). - var penaltyLastTokens: Int32 { - get { - wrapped.penalty_last_n - } - set { - wrapped.penalty_last_n = newValue - } - } - - /// Penalize repeated tokens (1.0 = disabled). - var penaltyRepeat: Float { - get { - wrapped.penalty_repeat - } - set { - wrapped.penalty_repeat = newValue - } - } - - /// Penalize frequency (0.0 = disabled). - var penaltyFrequency: Float { - get { - wrapped.penalty_repeat - } - set { - wrapped.penalty_repeat = newValue - } - } - - /// Presence penalty (0.0 = disabled). - var penaltyPresence: Float { - get { - wrapped.penalty_present - } - set { - wrapped.penalty_present = newValue - } - } - - /// Penalize new lines. - var penalizeNewLines: Bool { - get { - wrapped.penalize_nl - } - set { - wrapped.penalize_nl = newValue - } - } - - /// Mirostat sampling. - var mirostat: Mirostat { - get { - .init( - rawValue: Int(wrapped.mirostat), - targetEntropy: wrapped.mirostat_tau, - learningRate: wrapped.mirostat_eta - ) - } - set { - wrapped.mirostat = Int32(newValue.rawValue) - - if case .v1(let targetEntropy, let learningRate) = mirostat { - wrapped.mirostat_tau = targetEntropy - wrapped.mirostat_eta = learningRate - } else if case .v2(let targetEntropy, let learningRate) = mirostat { - wrapped.mirostat_tau = targetEntropy - wrapped.mirostat_eta = learningRate - } else { - wrapped.mirostat_tau = 5.0 - wrapped.mirostat_eta = 0.1 - } - } - } - - // C++ vector doesn't conform to Swift sequence on VisionOS SDK (Swift C++ Interop bug), - // therefore requiring workaround for VisionSDK - #if !os(visionOS) - /// Classifier-Free Guidance. 
- var cfg: ClassifierFreeGuidance { - get { - .init( - negativePrompt: String(wrapped.cfg_negative_prompt), - scale: wrapped.cfg_scale - ) - } - set { - if let negativePrompt = newValue.negativePrompt { - wrapped.cfg_negative_prompt = std.string(negativePrompt) - } - wrapped.cfg_scale = newValue.scale - } - } + let temperature: Float + /// Penalize repeated tokens (nil = disabled). + let penaltyRepeat: Float? + /// Number of tokens to consider for repetition penalty + let repetitionContextSize: Int - - /// Creates the ``LLMLocalContextParameters`` which wrap the underlying llama.cpp `llama_context_params` C struct. - /// Is passed to the underlying llama.cpp model in order to configure the context of the LLM. - /// - /// - Parameters: - /// - rememberTokens: Number of previous tokens to remember. - /// - outputProbabilities: If greater than 0, output the probabilities of top n\_probs tokens. - /// - topK: Top-K Sampling: K most likely next words (<= 0 to use vocab size). - /// - topP: Top-p Sampling: Smallest possible set of words whose cumulative probability exceeds the probability p (1.0 = disabled). - /// - minP: Min-p Sampling (0.0 = disabled). - /// - tfs: Tail Free Sampling (1.0 = disabled). - /// - typicalP: Locally Typical Sampling. - /// - temperature: Temperature Sampling: A higher value indicates more creativity of the model but also more hallucinations. - /// - penaltyLastTokens: Last n tokens to penalize (0 = disable penalty, -1 = context size). - /// - penaltyRepeat: Penalize repeated tokens (1.0 = disabled). - /// - penaltyFrequency: Penalize frequency (0.0 = disabled). - /// - penaltyPresence: Presence penalty (0.0 = disabled). - /// - penalizeNewLines: Penalize new lines. - /// - mirostat: Mirostat sampling. - /// - cfg: Classifier-Free Guidance. - public init( - rememberTokens: Int32 = 256, - outputProbabilities: Int32 = 0, - topK: Int32 = 40, - topP: Float = 0.95, - minP: Float = 0.05, - tfs: Float = 1.0, - typicalP: Float = 1.0, - temperature: Float = 0.8, - penaltyLastTokens: Int32 = 64, - penaltyRepeat: Float = 1.1, - penaltyFrequency: Float = 0.0, - penaltyPresence: Float = 0.0, - penalizeNewLines: Bool = true, - mirostat: Mirostat = .disabled, - cfg: ClassifierFreeGuidance = .init() - ) { - self.wrapped = llama_sampling_params() - - self.rememberTokens = rememberTokens - self.outputProbabilities = outputProbabilities - self.topK = topK - self.topP = topP - self.minP = minP - self.tfs = tfs - self.typicalP = typicalP - self.temperature = temperature - self.penaltyLastTokens = penaltyLastTokens - self.penaltyRepeat = penaltyRepeat - self.penaltyFrequency = penaltyFrequency - self.penaltyPresence = penaltyPresence - self.penalizeNewLines = penalizeNewLines - self.mirostat = mirostat - self.cfg = cfg - } - #else - /// Creates the ``LLMLocalContextParameters`` which wrap the underlying llama.cpp `llama_context_params` C struct. - /// Is passed to the underlying llama.cpp model in order to configure the context of the LLM. + + /// Creates the ``LLMLocalContextParameters`` /// /// - Parameters: - /// - rememberTokens: Number of previous tokens to remember. - /// - outputProbabilities: If greater than 0, output the probabilities of top n\_probs tokens. - /// - topK: Top-K Sampling: K most likely next words (<= 0 to use vocab size). /// - topP: Top-p Sampling: Smallest possible set of words whose cumulative probability exceeds the probability p (1.0 = disabled). - /// - minP: Min-p Sampling (0.0 = disabled). - /// - tfs: Tail Free Sampling (1.0 = disabled). 
- /// - typicalP: Locally Typical Sampling. /// - temperature: Temperature Sampling: A higher value indicates more creativity of the model but also more hallucinations. - /// - penaltyLastTokens: Last n tokens to penalize (0 = disable penalty, -1 = context size). - /// - penaltyRepeat: Penalize repeated tokens (1.0 = disabled). - /// - penaltyFrequency: Penalize frequency (0.0 = disabled). - /// - penaltyPresence: Presence penalty (0.0 = disabled). - /// - penalizeNewLines: Penalize new lines. - /// - mirostat: Mirostat sampling. + /// - penaltyRepeat: Penalize repeated tokens (nil = disabled). + /// - repetitionContextSize: Number of tokens to consider for repetition penalty public init( - rememberTokens: Int32 = 256, - outputProbabilities: Int32 = 0, - topK: Int32 = 40, - topP: Float = 0.95, - minP: Float = 0.05, - tfs: Float = 1.0, - typicalP: Float = 1.0, - temperature: Float = 0.8, - penaltyLastTokens: Int32 = 64, - penaltyRepeat: Float = 1.1, - penaltyFrequency: Float = 0.0, - penaltyPresence: Float = 0.0, - penalizeNewLines: Bool = true, - mirostat: Mirostat = .disabled + topP: Float = 1.0, + temperature: Float = 0.6, + penaltyRepeat: Float? = nil, + repetitionContextSize: Int = 20 ) { - self.wrapped = llama_sampling_params() - - self.rememberTokens = rememberTokens - self.outputProbabilities = outputProbabilities - self.topK = topK self.topP = topP - self.minP = minP - self.tfs = tfs - self.typicalP = typicalP self.temperature = temperature - self.penaltyLastTokens = penaltyLastTokens self.penaltyRepeat = penaltyRepeat - self.penaltyFrequency = penaltyFrequency - self.penaltyPresence = penaltyPresence - self.penalizeNewLines = penalizeNewLines - self.mirostat = mirostat + self.repetitionContextSize = repetitionContextSize } - #endif } diff --git a/Sources/SpeziLLMLocal/Helpers/LLMModel+numParameters.swift b/Sources/SpeziLLMLocal/Helpers/LLMModel+numParameters.swift new file mode 100644 index 0000000..aa01bd4 --- /dev/null +++ b/Sources/SpeziLLMLocal/Helpers/LLMModel+numParameters.swift @@ -0,0 +1,30 @@ +// +// This source file is part of the Stanford Spezi open source project +// +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) +// +// SPDX-License-Identifier: MIT +// + +import MLXNN + + +extension Module { + /// Compute the number of parameters in a possibly quantized model + public func numParameters() -> Int { + leafModules() + .flattenedValues() + .map { mod -> Int in + if let quantized = mod as? QuantizedLinear { + return quantized.scales.size * quantized.groupSize + } else if let quantized = mod as? QuantizedEmbedding { + return quantized.scales.size * quantized.groupSize + } else { + return mod.parameters() + .flattenedValues() + .reduce(0) { $0 + $1.size } + } + } + .reduce(0, +) + } +} diff --git a/Sources/SpeziLLMLocal/Helpers/String+Cxx.swift b/Sources/SpeziLLMLocal/Helpers/String+Cxx.swift deleted file mode 100644 index 4367986..0000000 --- a/Sources/SpeziLLMLocal/Helpers/String+Cxx.swift +++ /dev/null @@ -1,30 +0,0 @@ -// -// This source file is part of the Stanford Spezi open source project -// -// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md) -// -// SPDX-License-Identifier: MIT -// - -import Foundation - - -extension String { - /// Initializes a Swift `String` from a C++ `string`. 
- /// - /// - Parameters: - /// - cxxString: The given C++ `string` - /// - /// In the Release build mode, the Swift compiler is unable to choose the correct String initializer from the Swift stdlib. - /// Therefore, manual `String `extension by SpeziLLM that mirrors the C++ interop implementation within the Swift stdlib: https://github.com/apple/swift/blob/cf2a338afca54a787d59b83db6238b1568215b94/stdlib/public/Cxx/std/String.swift#L231-L239 - init(_ cxxString: std.string) { - let buffer = UnsafeBufferPointer( - start: cxxString.__c_strUnsafe(), - count: cxxString.size() - ) - self = buffer.withMemoryRebound(to: UInt8.self) { - String(decoding: $0, as: UTF8.self) - } - withExtendedLifetime(cxxString) {} - } -} diff --git a/Sources/SpeziLLMLocal/LLMLocalError.swift b/Sources/SpeziLLMLocal/LLMLocalError.swift index 4f3b96a..c6f7ada 100644 --- a/Sources/SpeziLLMLocal/LLMLocalError.swift +++ b/Sources/SpeziLLMLocal/LLMLocalError.swift @@ -1,7 +1,7 @@ // // This source file is part of the Stanford Spezi open source project // -// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md) +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) // // SPDX-License-Identifier: MIT // diff --git a/Sources/SpeziLLMLocal/LLMLocalPlatform.swift b/Sources/SpeziLLMLocal/LLMLocalPlatform.swift index e6adc77..a8930de 100644 --- a/Sources/SpeziLLMLocal/LLMLocalPlatform.swift +++ b/Sources/SpeziLLMLocal/LLMLocalPlatform.swift @@ -1,13 +1,13 @@ // // This source file is part of the Stanford Spezi open source project // -// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md) +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) // // SPDX-License-Identifier: MIT // import Foundation -import llama +import MLX import Spezi import SpeziFoundation import SpeziLLM @@ -39,13 +39,10 @@ import SpeziLLM /// } /// ``` public actor LLMLocalPlatform: LLMPlatform, DefaultInitializable { - /// Enforce only one concurrent execution of a local LLM. - private let semaphore = AsyncSemaphore(value: 1) let configuration: LLMLocalPlatformConfiguration @MainActor public var state: LLMPlatformState = .idle - /// Creates an instance of the ``LLMLocalPlatform``. 
/// /// - Parameters: @@ -59,34 +56,23 @@ public actor LLMLocalPlatform: LLMPlatform, DefaultInitializable { self.init(configuration: .init()) } - public nonisolated func configure() { - // Initialize the llama.cpp backend - llama_backend_init() - llama_numa_init(configuration.nonUniformMemoryAccess.wrappedValue) +#if targetEnvironment(simulator) + assertionFailure("SpeziLLMLocal: Code cannot be run on simulator.") +#endif + if let cacheLimit = configuration.cacheLimit { + MLX.GPU.set(cacheLimit: cacheLimit * 1024 * 1024) + } + if let memoryLimit = configuration.memoryLimit { + MLX.GPU.set(memoryLimit: memoryLimit.limit, relaxed: memoryLimit.relaxed) + } } public nonisolated func callAsFunction(with llmSchema: LLMLocalSchema) -> LLMLocalSession { LLMLocalSession(self, schema: llmSchema) } - nonisolated func exclusiveAccess() async throws { - try await semaphore.waitCheckingCancellation() - await MainActor.run { - state = .processing - } - } - - nonisolated func signal() async { - semaphore.signal() - await MainActor.run { - state = .idle - } - } - - deinit { - // Frees the llama.cpp backend - llama_backend_free() + MLX.GPU.clearCache() } } diff --git a/Sources/SpeziLLMLocal/LLMLocalSchema.swift b/Sources/SpeziLLMLocal/LLMLocalSchema.swift index 40204d2..3bf3ca2 100644 --- a/Sources/SpeziLLMLocal/LLMLocalSchema.swift +++ b/Sources/SpeziLLMLocal/LLMLocalSchema.swift @@ -1,12 +1,13 @@ // // This source file is part of the Stanford Spezi open source project // -// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md) +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) // // SPDX-License-Identifier: MIT // import Foundation +import MLXLLM import SpeziChat import SpeziLLM @@ -20,10 +21,7 @@ import SpeziLLM public struct LLMLocalSchema: LLMSchema { public typealias Platform = LLMLocalPlatform - - /// The on-device `URL` where the model is located. - let modelPath: URL - /// Parameters of the llama.cpp LLM. + /// Closure to properly format the ``LLMLocal/context`` to a `String` which is tokenized and passed to the LLM. let parameters: LLMLocalParameters /// Context parameters of the llama.cpp LLM. let contextParameters: LLMLocalContextParameters @@ -31,31 +29,33 @@ public struct LLMLocalSchema: LLMSchema { let samplingParameters: LLMLocalSamplingParameters /// Closure to properly format the ``LLMLocal/context`` to a `String` which is tokenized and passed to the LLM. let formatChat: (@Sendable (LLMContext) throws -> String) + /// Indicates if the inference output by the ``LLMLocalSession`` should automatically be inserted into the ``LLMLocalSession/context``. public let injectIntoContext: Bool - + /// The models configuration which is based on `mlx-libraries` + internal let configuration: ModelConfiguration /// Creates an instance of the ``LLMLocalSchema`` containing all necessary configuration for local LLM inference. /// /// - Parameters: - /// - modelPath: A local `URL` where the LLM file is stored. The format of the LLM must be in the llama.cpp `.gguf` format. - /// - parameters: Parameterize the LLM via ``LLMLocalParameters``. - /// - contextParameters: Configure the context of the LLM via ``LLMLocalContextParameters``. - /// - samplingParameters: Parameterize the sampling methods of the LLM via ``LLMLocalSamplingParameters``. + /// - configuration: A local `URL` where the LLM file is stored. The format of the LLM must be in the llama.cpp `.gguf` format. 
+ /// - parameters: Parameterize the LLM via ``LLMLocalParameters``. + /// - contextParameters: Configure the context of the LLM via ``LLMLocalContextParameters``. + /// - samplingParameters: Parameterize the sampling methods of the LLM via ``LLMLocalSamplingParameters``. /// - injectIntoContext: Indicates if the inference output by the ``LLMLocalSession`` should automatically be inserted into the ``LLMLocalSession/context``, defaults to false. /// - formatChat: Closure to properly format the ``LLMLocalSession/context`` to a `String` which is tokenized and passed to the LLM, defaults to Llama2 prompt format. public init( - modelPath: URL, + model: LLMLocalModel, parameters: LLMLocalParameters = .init(), contextParameters: LLMLocalContextParameters = .init(), samplingParameters: LLMLocalSamplingParameters = .init(), injectIntoContext: Bool = false, - formatChat: @escaping (@Sendable (LLMContext) throws -> String) = PromptFormattingDefaults.llama2 + formatChat: @escaping (@Sendable (LLMContext) throws -> String) ) { - self.modelPath = modelPath self.parameters = parameters self.contextParameters = contextParameters self.samplingParameters = samplingParameters - self.injectIntoContext = injectIntoContext self.formatChat = formatChat + self.injectIntoContext = injectIntoContext + self.configuration = .init(id: model.hubID) } } diff --git a/Sources/SpeziLLMLocal/LLMLocalSession+Generate.swift b/Sources/SpeziLLMLocal/LLMLocalSession+Generate.swift new file mode 100644 index 0000000..8f9d6e3 --- /dev/null +++ b/Sources/SpeziLLMLocal/LLMLocalSession+Generate.swift @@ -0,0 +1,113 @@ +// +// This source file is part of the Stanford Spezi open source project +// +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) +// +// SPDX-License-Identifier: MIT +// + +import Foundation +import MLX +import MLXLLM +import MLXRandom +import os +import SpeziChat +import SpeziLLM + + +extension LLMLocalSession { + // swiftlint:disable:next identifier_name function_body_length + internal func _generate(continuation: AsyncThrowingStream<String, Error>.Continuation) async { + guard let modelContainer = await self.modelContainer else { + Self.logger.error("SpeziLLMLocal: Failed to load `modelContainer`") + await finishGenerationWithError(LLMLocalError.modelNotFound, on: continuation) + return + } + + let modelConfiguration = self.schema.configuration + + guard let formattedChat = try? await schema.formatChat(self.context) else { + Self.logger.error("SpeziLLMLocal: Failed to format chat with given context") + await finishGenerationWithError(LLMLocalError.illegalContext, on: continuation) + return + } + + let prompt = modelConfiguration.prepare(prompt: formattedChat) + let promptTokens = await modelContainer.perform { _, tokenizer in + tokenizer.encode(text: prompt) + } + + MLXRandom.seed(self.schema.contextParameters.seed ??
UInt64(Date.timeIntervalSinceReferenceDate * 1000)) + + let extraEOSTokens = modelConfiguration.extraEOSTokens + + guard await !checkCancellation(on: continuation) else { + return + } + + let parameters: GenerateParameters = .init( + temperature: schema.samplingParameters.temperature, + topP: schema.samplingParameters.topP, + repetitionPenalty: schema.samplingParameters.penaltyRepeat, + repetitionContextSize: schema.samplingParameters.repetitionContextSize + ) + + let (result, tokenizer) = await modelContainer.perform { model, tokenizer in + let result = MLXLLM.generate( + promptTokens: promptTokens, + parameters: parameters, + model: model, + tokenizer: tokenizer, + extraEOSTokens: extraEOSTokens + ) { tokens in + if Task.isCancelled { + return .stop + } + + if tokens.count >= self.schema.parameters.maxOutputLength { + Self.logger.debug("SpeziLLMLocal: Max output length exceeded.") + continuation.finish() + Task { @MainActor in + self.state = .ready + } + return .stop + } + + if schema.injectIntoContext && tokens.count.isMultiple(of: schema.parameters.displayEveryNTokens) { + let lastTokens = Array(tokens.suffix(schema.parameters.displayEveryNTokens)) + let text = tokenizer.decode(tokens: lastTokens) + + Self.logger.debug("SpeziLLMLocal: Yielded token: \(text, privacy: .public)") + continuation.yield(text) + } + + return .more + } + + return (result, tokenizer) + } + + Self.logger.debug( + """ + SpeziLLMLocal: + Prompt Tokens per second: \(result.promptTokensPerSecond, privacy: .public) + Generation tokens per second: \(result.tokensPerSecond, privacy: .public) + """ + ) + + await MainActor.run { + if schema.injectIntoContext { + // Yielding every Nth token may result in missing the final tokens. + let reaminingTokens = result.tokens.count % schema.parameters.displayEveryNTokens + let lastTokens = Array(result.tokens.suffix(reaminingTokens)) + let text = tokenizer.decode(tokens: lastTokens) + continuation.yield(text) + } + + context.append(assistantOutput: result.output, complete: true) + context.completeAssistantStreaming() + continuation.finish() + state = .ready + } + } +} diff --git a/Sources/SpeziLLMLocal/LLMLocalSession+Generation.swift b/Sources/SpeziLLMLocal/LLMLocalSession+Generation.swift deleted file mode 100644 index cbdb0ad..0000000 --- a/Sources/SpeziLLMLocal/LLMLocalSession+Generation.swift +++ /dev/null @@ -1,194 +0,0 @@ -// -// This source file is part of the Stanford Spezi open source project -// -// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md) -// -// SPDX-License-Identifier: MIT -// - -import Foundation -import llama -import SpeziLLM - - -/// Extension of ``LLMLocalSession`` handling the text generation. -extension LLMLocalSession { - /// Typealias for the llama.cpp `llama_token`. - typealias LLMLocalToken = llama_token - - - /// Based on the input prompt, generate the output with llama.cpp - /// - /// - Parameters: - /// - continuation: A Swift `AsyncThrowingStream` that streams the generated output. 
- func _generate( // swiftlint:disable:this identifier_name function_body_length cyclomatic_complexity - continuation: AsyncThrowingStream.Continuation - ) async { - Self.logger.debug("SpeziLLMLocal: Local LLM started a new inference") - - await MainActor.run { - self.state = .generating - } - - // Log the most important parameters of the LLM - Self.logger.debug("SpeziLLMLocal: n_length = \(self.schema.parameters.maxOutputLength, privacy: .public), n_ctx = \(self.schema.contextParameters.contextWindowSize, privacy: .public), n_batch = \(self.schema.contextParameters.batchSize, privacy: .public), n_kv_req = \(self.schema.parameters.maxOutputLength, privacy: .public)") - - // Allocate new model context, if not already present - if self.modelContext == nil { - guard let context = llama_new_context_with_model(model, schema.contextParameters.llamaCppRepresentation) else { - Self.logger.error("SpeziLLMLocal: Failed to initialize context") - await finishGenerationWithError(LLMLocalError.generationError, on: continuation) - return - } - self.modelContext = context - } - - // Check if the maximal output generation length is smaller or equals to the context window size. - guard schema.parameters.maxOutputLength <= schema.contextParameters.contextWindowSize else { - Self.logger.error("SpeziLLMLocal: Error: n_kv_req \(self.schema.parameters.maxOutputLength, privacy: .public) > n_ctx, the required KV cache size is not big enough") - await finishGenerationWithError(LLMLocalError.generationError, on: continuation) - return - } - - // Tokenizes the entire context of the LLM - guard let tokens = try? await tokenize() else { - Self.logger.error(""" - SpeziLLMLocal: Tokenization failed as illegal context exists. - Ensure the content of the context is structured in: System Prompt, User prompt, and an - arbitrary number of assistant responses and follow up user prompts. - """) - await finishGenerationWithError(LLMLocalError.illegalContext, on: continuation) - return - } - - guard await !checkCancellation(on: continuation) else { - return - } - - // Check if the input token count is smaller than the context window size decremented by 4 (space for end tokens). - guard tokens.count <= schema.contextParameters.contextWindowSize - 4 else { - Self.logger.error(""" - SpeziLLMLocal: Input prompt is too long with \(tokens.count, privacy: .public) tokens for the configured - context window size of \(self.schema.contextParameters.contextWindowSize, privacy: .public) tokens. - """) - await finishGenerationWithError(LLMLocalError.generationError, on: continuation) - return - } - - // Clear the KV cache in order to free up space for the incoming prompt (as we inject the entire history of the chat again) - llama_kv_cache_clear(self.modelContext) - - var batch = llama_batch_init(Int32(tokens.count), 0, 1) - defer { - llama_batch_free(batch) - } - - // Evaluate the initial prompt - for (tokenIndex, token) in tokens.enumerated() { - llama_batch_add(&batch, token, Int32(tokenIndex), getLlamaSeqIdVector(), false) - } - // llama_decode will output logits only for the last token of the prompt - batch.logits[Int(batch.n_tokens) - 1] = 1 - - guard await !checkCancellation(on: continuation) else { - return - } - - if llama_decode(self.modelContext, batch) != 0 { - Self.logger.error(""" - SpeziLLMLocal: Initial prompt decoding as failed! 
- """) - await finishGenerationWithError(LLMLocalError.generationError, on: continuation) - return - } - - guard await !checkCancellation(on: continuation) else { - return - } - - // Batch already includes tokens from the input prompt - var batchTokenIndex = batch.n_tokens - var decodedTokens = 0 - - // Calculate the token generation rate - let startTime = Date() - - while decodedTokens <= schema.parameters.maxOutputLength { - guard await !checkCancellation(on: continuation) else { - return - } - - let nextTokenId = sample(batchSize: batch.n_tokens) - - // Either finish the generation once EOS token appears, the maximum output length of the answer is reached or the context window is reached - if nextTokenId == llama_token_eos(self.model) - || decodedTokens == schema.parameters.maxOutputLength - || batchTokenIndex == schema.contextParameters.contextWindowSize { - continuation.finish() - await MainActor.run { - self.state = .ready - } - return - } - - var nextStringPiece = String(llama_token_to_piece(self.modelContext, nextTokenId, true)) - // As first character is sometimes randomly prefixed by a single space (even though prompt has an additional character) - if decodedTokens == 0 && nextStringPiece.starts(with: " ") { - nextStringPiece = String(nextStringPiece.dropFirst()) - } - - // Yield the response from the model to the Stream - Self.logger.debug(""" - SpeziLLMLocal: Yielded token: \(nextStringPiece, privacy: .public) - """) - - // Automatically inject the yielded string piece into the `LLMLocal/context` - if schema.injectIntoContext && nextTokenId != 0 { - let nextStringPiece = nextStringPiece - await MainActor.run { - context.append(assistantOutput: nextStringPiece) - } - } - - if nextTokenId != 0 { - continuation.yield(nextStringPiece) - } - - // Prepare the next batch - llama_batch_clear(&batch) - - // Push generated output token for the next evaluation round - llama_batch_add(&batch, nextTokenId, batchTokenIndex, getLlamaSeqIdVector(), true) - - decodedTokens += 1 - batchTokenIndex += 1 - - // Evaluate the current batch with the transformer model - let decodeOutput = llama_decode(self.modelContext, batch) - if decodeOutput != 0 { // = 0 Success, > 0 Warning, < 0 Error - Self.logger.error("SpeziLLMLocal: Decoding of generated output failed. 
Output: \(decodeOutput, privacy: .public)") - await finishGenerationWithError(LLMLocalError.generationError, on: continuation) - return - } - } - - let elapsedTime = Date().timeIntervalSince(startTime) - - Self.logger.debug("SpeziLLMLocal: Decoded \(decodedTokens, privacy: .public) tokens in \(String(format: "%.2f", elapsedTime), privacy: .public) s, speed: \(String(format: "%.2f", Double(decodedTokens) / elapsedTime), privacy: .public)) t/s") - - llama_print_timings(self.modelContext) - - continuation.finish() - if schema.injectIntoContext { - await MainActor.run { - context.completeAssistantStreaming() - } - } - - await MainActor.run { - self.state = .ready - } - - Self.logger.debug("SpeziLLMLocal: Local LLM completed an inference") - } -} diff --git a/Sources/SpeziLLMLocal/LLMLocalSession+Sampling.swift b/Sources/SpeziLLMLocal/LLMLocalSession+Sampling.swift deleted file mode 100644 index 942e282..0000000 --- a/Sources/SpeziLLMLocal/LLMLocalSession+Sampling.swift +++ /dev/null @@ -1,46 +0,0 @@ -// -// This source file is part of the Stanford Spezi open source project -// -// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md) -// -// SPDX-License-Identifier: MIT -// - -import Foundation -import llama - - -extension LLMLocalSession { - /// Based on the current state of the context, sample the to be inferred output via the temperature method - /// - /// - Parameters: - /// - batchSize: The current size of the `llama_batch` - /// - Returns: A sampled `LLMLocalToken` - func sample(batchSize: Int32) -> LLMLocalToken { - let nVocab = llama_n_vocab(model) - let logits = llama_get_logits_ith(self.modelContext, batchSize - 1) - - var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(nVocab)) - - for tokenId in 0 ..< nVocab { - candidates.append(llama_token_data(id: tokenId, logit: logits?[Int(tokenId)] ?? 
0, p: 0.0)) - } - - var candidatesP: llama_token_data_array = .init( - data: candidates.withUnsafeMutableBytes { $0.baseAddress?.assumingMemoryBound(to: llama_token_data.self) }, // &candidates - size: candidates.count, - sorted: false - ) - - // Sample via the temperature method - let minKeep = Int(max(1, schema.samplingParameters.outputProbabilities)) - llama_sample_top_k(modelContext, &candidatesP, schema.samplingParameters.topK, minKeep) - llama_sample_tail_free(modelContext, &candidatesP, schema.samplingParameters.tfs, minKeep) - llama_sample_typical(modelContext, &candidatesP, schema.samplingParameters.typicalP, minKeep) - llama_sample_top_p(modelContext, &candidatesP, schema.samplingParameters.topP, minKeep) - llama_sample_min_p(modelContext, &candidatesP, schema.samplingParameters.minP, minKeep) - llama_sample_temp(modelContext, &candidatesP, schema.samplingParameters.temperature) - - return llama_sample_token(modelContext, &candidatesP) - } -} diff --git a/Sources/SpeziLLMLocal/LLMLocalSession+Setup.swift b/Sources/SpeziLLMLocal/LLMLocalSession+Setup.swift index c60d279..4ebb457 100644 --- a/Sources/SpeziLLMLocal/LLMLocalSession+Setup.swift +++ b/Sources/SpeziLLMLocal/LLMLocalSession+Setup.swift @@ -1,48 +1,63 @@ // // This source file is part of the Stanford Spezi open source project // -// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md) +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) // // SPDX-License-Identifier: MIT // -import llama +import Foundation +import Hub +import MLXLLM extension LLMLocalSession { - /// Set up the local LLM execution environment via llama.cpp - /// - /// - Parameters: - /// - continuation: A Swift `AsyncThrowingStream` that streams the generated output. - /// - Returns: `true` if the setup was successful, `false` otherwise. - func setup(continuation: AsyncThrowingStream.Continuation) async -> Bool { + private func verifyModelDownload() -> Bool { + let repo = Hub.Repo(id: self.schema.configuration.name) + let url = HubApi.shared.localRepoLocation(repo) + let modelFileExtension = ".safetensors" + + do { + let contents = try FileManager.default.contentsOfDirectory(atPath: url.path()) + return contents.contains { $0.hasSuffix(modelFileExtension) } + } catch { + return false + } + } + + // swiftlint:disable:next identifier_name + internal func _setup(continuation: AsyncThrowingStream.Continuation?) 
async -> Bool { Self.logger.debug("SpeziLLMLocal: Local LLM is being initialized") + await MainActor.run { - state = .loading + self.state = .loading } - guard let model = llama_load_model_from_file(schema.modelPath.path().cString(using: .utf8), schema.parameters.llamaCppRepresentation) else { - await finishGenerationWithError(LLMLocalError.modelNotFound, on: continuation) + guard verifyModelDownload() else { + if let continuation { + await finishGenerationWithError(LLMLocalError.modelNotFound, on: continuation) + } Self.logger.error("SpeziLLMLocal: Local LLM file could not be opened, indicating that the model file doesn't exist") return false } - /// Check if model was trained for the configured context window size - guard schema.contextParameters.contextWindowSize <= llama_n_ctx_train(model) else { - await finishGenerationWithError(LLMLocalError.contextSizeMismatch, on: continuation) - Self.logger.error(""" - SpeziLLMLocal: Model was trained on only \(llama_n_ctx_train(model), privacy: .public) context tokens, - not the configured \(self.schema.contextParameters.contextWindowSize, privacy: .public) context tokens - """) + do { + let modelContainer = try await loadModelContainer(configuration: self.schema.configuration) + + let numParams = await modelContainer.perform { [] model, _ in + model.numParameters() + } + + await MainActor.run { + self.modelContainer = modelContainer + self.numParameters = numParams + self.state = .ready + } + } catch { + continuation?.yield(with: .failure(error)) + Self.logger.error("SpeziLLMLocal: Failed to load local `modelContainer`") return false } - - self.model = model - - await MainActor.run { - state = .ready - } - Self.logger.debug("SpeziLLMLocal: Local LLM finished initializing, now ready to use") return true } } diff --git a/Sources/SpeziLLMLocal/LLMLocalSession+Tokenization.swift b/Sources/SpeziLLMLocal/LLMLocalSession+Tokenization.swift deleted file mode 100644 index 5ea001f..0000000 --- a/Sources/SpeziLLMLocal/LLMLocalSession+Tokenization.swift +++ /dev/null @@ -1,81 +0,0 @@ -// -// This source file is part of the Stanford Spezi open source project -// -// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md) -// -// SPDX-License-Identifier: MIT -// - -import Foundation -import llama - - -/// Extension of ``LLMLocalSession`` handling the text tokenization. -extension LLMLocalSession { - /// Converts the current context of the model to the individual `LLMLocalToken`'s based on the model's dictionary. - /// This is a required tasks as LLMs internally processes tokens. - /// - /// - Returns: The tokenized `String` as `LLMLocalToken`'s. 
- func tokenize() async throws -> [LLMLocalToken] { - // Format the chat into a prompt that conforms to the prompt structure of the respective LLM - let formattedChat = try await schema.formatChat(self.context) - - // C++ vector doesn't conform to Swift sequence on VisionOS SDK (Swift C++ Interop bug), - // therefore requiring workaround for VisionSDK - #if !os(visionOS) - var tokens: [LLMLocalToken] = .init( - llama_tokenize_with_context(self.modelContext, std.string(formattedChat), schema.parameters.addBosToken, true) - ) - #else - // Swift String to C++ String buggy on VisionOS, workaround via C-based `char` array - guard let cString = formattedChat.cString(using: .utf8) else { - fatalError("SpeziLLMLocal: Couldn't bridge the LLM Swift-based String context to a C-based String.") - } - - let cxxTokensVector = llama_tokenize_with_context_from_char_array(self.modelContext, cString, schema.parameters.addBosToken, true) - - // Get C array from C++ vector containing the tokenized content - guard var cxxTokensArray = vectorToIntArray(cxxTokensVector) else { - fatalError("SpeziLLMLocal: Couldn't get C array containing the tokenized content from C++ vector.") - } - - // Extract tokens from C array to a Swift array - var tokens: [LLMLocalToken] = [] - - for _ in 0...cxxTokensVector.size() { - tokens.append(cxxTokensArray.pointee) - cxxTokensArray = cxxTokensArray.advanced(by: 1) - } - #endif - - // Truncate tokens if there wouldn't be enough context size for the generated output - if tokens.count > Int(schema.contextParameters.contextWindowSize) - schema.parameters.maxOutputLength { - tokens = Array(tokens.suffix(Int(schema.contextParameters.contextWindowSize) - schema.parameters.maxOutputLength)) - } - - // Output generation shouldn't run without any tokens - if tokens.isEmpty { - tokens.append(llama_token_bos(self.model)) - Self.logger.warning(""" - SpeziLLMLocal: The input prompt didn't map to any tokens, so the prompt was considered empty. - To mediate this issue, a BOS token was added to the prompt so that the output generation - doesn't run without any tokens. - """) - } - - return tokens - } - - /// Converts an array of `LLMLocalToken`s to an array of tupels of `LLMLocalToken`s as well as their `String` representation. - /// - /// - Parameters: - /// - tokens: An array of `LLMLocalToken`s that should be detokenized. - /// - Returns: An array of tupels of `LLMLocalToken`s as well as their `String` representation. 
- /// - /// - Note: Used only for debug purposes - func detokenize(tokens: [LLMLocalToken]) -> [(LLMLocalToken, String)] { - tokens.reduce(into: [(LLMLocalToken, String)]()) { partialResult, token in - partialResult.append((token, String(llama_token_to_piece(self.modelContext, token, true)))) - } - } -} diff --git a/Sources/SpeziLLMLocal/LLMLocalSession.swift b/Sources/SpeziLLMLocal/LLMLocalSession.swift index 6771ba8..026859f 100644 --- a/Sources/SpeziLLMLocal/LLMLocalSession.swift +++ b/Sources/SpeziLLMLocal/LLMLocalSession.swift @@ -1,12 +1,16 @@ // // This source file is part of the Stanford Spezi open source project // -// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md) +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) // // SPDX-License-Identifier: MIT // + import Foundation +import MLX +import MLXLLM +import MLXRandom import os import SpeziChat import SpeziLLM @@ -65,16 +69,19 @@ public final class LLMLocalSession: LLMSession, @unchecked Sendable { let platform: LLMLocalPlatform let schema: LLMLocalSchema + @ObservationIgnored private var modelExist: Bool { + false + } + /// A task managing the ``LLMLocalSession`` output generation. @ObservationIgnored private var task: Task<(), Never>? @MainActor public var state: LLMState = .uninitialized @MainActor public var context: LLMContext = [] - /// A pointer to the allocated model via llama.cpp. - @ObservationIgnored var model: OpaquePointer? - /// A pointer to the allocated model context from llama.cpp. - @ObservationIgnored var modelContext: OpaquePointer? + @MainActor public var numParameters: Int? + @MainActor public var modelConfiguration: ModelConfiguration? + @MainActor public var modelContainer: ModelContainer? /// Creates an instance of a ``LLMLocalSession`` responsible for LLM inference. @@ -95,25 +102,28 @@ public final class LLMLocalSession: LLMSession, @unchecked Sendable { } } + /// Initializes the model in advance. + /// Calling this method before user interaction prepares the model, which leads to reduced response time for the first prompt. + public func setup() async throws { + guard await _setup(continuation: nil) else { + throw LLMLocalError.modelNotReadyYet + } + } + + /// Based on the input prompt, generate the output. + /// - Returns: A Swift `AsyncThrowingStream` that streams the generated output. 
@discardableResult public func generate() async throws -> AsyncThrowingStream { - try await platform.exclusiveAccess() - let (stream, continuation) = AsyncThrowingStream.makeStream(of: String.self) - // Execute the output generation of the LLM task = Task(priority: platform.configuration.taskPriority) { - // Unregister as soon as `Task` finishes - defer { - Task { - await platform.signal() - } - } - - // Setup the model, if not already done - if model == nil { - guard await setup(continuation: continuation) else { + if await state == .uninitialized { + guard await _setup(continuation: continuation) else { + await MainActor.run { + state = .error(error: LLMLocalError.modelNotReadyYet) + } + await finishGenerationWithError(LLMLocalError.modelNotReadyYet, on: continuation) return } } @@ -122,18 +132,22 @@ public final class LLMLocalSession: LLMSession, @unchecked Sendable { return } - // Execute the inference + await MainActor.run { + self.state = .generating + } + + // Execute the output generation of the LLM await _generate(continuation: continuation) } return stream } + public func cancel() { task?.cancel() } - deinit { cancel() } diff --git a/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager+DefaultUrls.swift b/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager+DefaultUrls.swift deleted file mode 100644 index 07dad7a..0000000 --- a/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager+DefaultUrls.swift +++ /dev/null @@ -1,92 +0,0 @@ -// -// This source file is part of the Stanford Spezi open source project -// -// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md) -// -// SPDX-License-Identifier: MIT -// - -import Foundation - - -extension LLMLocalDownloadManager { - /// Defaults of possible LLMs to download via the ``LLMLocalDownloadManager``. - public enum LLMUrlDefaults { - /// LLama 3 8B model with `Q4_K_M` quantization in its instruct variation (~5 GB) - public static var llama3InstructModelUrl: URL { - guard let url = URL(string: "https://huggingface.co/QuantFactory/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct.Q4_K_M.gguf") else { - preconditionFailure(""" - SpeziLLM: Invalid LLMUrlDefaults LLM download URL. - """) - } - - return url - } - - /// LLama 2 7B model with `Q4_K_M` quantization in its chat variation (~3.5GB) - public static var llama2ChatModelUrl: URL { - guard let url = URL(string: "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_K_M.gguf") else { - preconditionFailure(""" - SpeziLLM: Invalid LLMUrlDefaults LLM download URL. - """) - } - - return url - } - - /// LLama 2 13B model with `Q4_K_M` quantization in its chat variation (~7GB) - public static var llama2Chat13BModelUrl: URL { - guard let url = URL(string: "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML/resolve/main/llama-2-13b-chat.ggmlv3.q4_K_M.bin") else { - preconditionFailure(""" - SpeziLLM: Invalid LLMUrlDefaults LLM download URL. - """) - } - - return url - } - - /// Phi-2 model with `Q5_K_M` quantization (~2GB) - public static var phi2ModelUrl: URL { - guard let url = URL(string: "https://huggingface.co/TheBloke/phi-2-GGUF/resolve/main/phi-2.Q5_K_M.gguf") else { - preconditionFailure(""" - SpeziLLM: Invalid LLMUrlDefaults LLM download URL. 
- """) - } - - return url - } - - /// Gemma 7B model with `Q4_K_M` quantization (~5GB) - public static var gemma7BModelUrl: URL { - guard let url = URL(string: "https://huggingface.co/rahuldshetty/gemma-7b-it-gguf-quantized/resolve/main/gemma-7b-it-Q4_K_M.gguf") else { - preconditionFailure(""" - SpeziLLM: Invalid LLMUrlDefaults LLM download URL. - """) - } - - return url - } - - /// Gemma 2B model with `Q4_K_M` quantization (~1.5GB) - public static var gemma2BModelUrl: URL { - guard let url = URL(string: "https://huggingface.co/rahuldshetty/gemma-2b-gguf-quantized/resolve/main/gemma-2b-Q4_K_M.gguf") else { - preconditionFailure(""" - SpeziLLM: Invalid LLMUrlDefaults LLM download URL. - """) - } - - return url - } - - /// Tiny LLama 1.1B model with `Q5_K_M` quantization in its chat variation (~800MB) - public static var tinyLLama2ModelUrl: URL { - guard let url = URL(string: "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf") else { - preconditionFailure(""" - SpeziLLM: Invalid LLMUrlDefaults LLM download URL. - """) - } - - return url - } - } -} diff --git a/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager.swift b/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager.swift index e2788b2..1e1b2b8 100644 --- a/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager.swift +++ b/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager.swift @@ -7,10 +7,12 @@ // import Foundation +import Hub +import MLXLLM import Observation +import SpeziLLMLocal import SpeziViews - /// Manages the download and storage of Large Language Models (LLM) to the local device. /// /// One configures the ``LLMLocalDownloadManager`` via the ``LLMLocalDownloadManager/init(llmDownloadUrl:llmStorageUrl:)`` initializer, @@ -25,7 +27,7 @@ public final class LLMLocalDownloadManager: NSObject { /// An enum containing all possible states of the ``LLMLocalDownloadManager``. public enum DownloadState: Equatable { case idle - case downloading(progress: Double) + case downloading(progress: Progress) case downloaded(storageUrl: URL) case error(LocalizedError) @@ -41,41 +43,73 @@ public final class LLMLocalDownloadManager: NSObject { } } - /// The delegate handling the download manager tasks. - @ObservationIgnored private var downloadDelegate: LLMLocalDownloadManagerDelegate? // swiftlint:disable:this weak_delegate /// The `URLSessionDownloadTask` that handles the download of the model. - @ObservationIgnored private var downloadTask: URLSessionDownloadTask? - /// Remote `URL` from where the LLM file should be downloaded. - private let llmDownloadUrl: URL - /// Local `URL` where the downloaded model is stored. - let llmStorageUrl: URL + @ObservationIgnored private var downloadTask: Task<(), Never>? /// Indicates the current state of the ``LLMLocalDownloadManager``. @MainActor public var state: DownloadState = .idle + private let modelConfiguration: ModelConfiguration + @ObservationIgnored public var modelExists: Bool { + LLMLocalDownloadManager.modelExsist(model: .custom(id: modelConfiguration.name)) + } - /// Creates a ``LLMLocalDownloadManager`` that helps with downloading LLM files from remote servers. + /// Initializes a ``LLMLocalDownloadManager`` instance to manage the download of Large Language Model (LLM) files from remote servers. /// /// - Parameters: - /// - llmDownloadUrl: The remote `URL` from where the LLM file should be downloaded. - /// - llmStorageUrl: The local `URL` where the LLM file should be stored. 
- public init( - llmDownloadUrl: URL = LLMUrlDefaults.llama2ChatModelUrl, - llmStorageUrl: URL = .cachesDirectory.appending(path: "llm.gguf") - ) { - self.llmDownloadUrl = llmDownloadUrl - self.llmStorageUrl = llmStorageUrl + /// - modelID: The Huggingface model ID of the LLM that needs to be downloaded. + public init(model: LLMLocalModel) { + self.modelConfiguration = .init(id: model.hubID) } + /// Checks if a model is already downloaded to the local device. + /// + /// - Parameter model: The model to check for local existence. + /// - Returns: A Boolean value indicating whether the model exists on the device. + public static func modelExsist(model: LLMLocalModel) -> Bool { + let repo = Hub.Repo(id: model.hubID) + let url = HubApi.shared.localRepoLocation(repo) + let modelFileExtension = ".safetensors" + + do { + let contents = try FileManager.default.contentsOfDirectory(atPath: url.path()) + return contents.contains { $0.hasSuffix(modelFileExtension) } + } catch { + return false + } + } /// Starts a `URLSessionDownloadTask` to download the specified model. public func startDownload() { - downloadTask?.cancel() - - downloadDelegate = LLMLocalDownloadManagerDelegate(manager: self, storageUrl: llmStorageUrl) - let session = URLSession(configuration: .default, delegate: downloadDelegate, delegateQueue: nil) - downloadTask = session.downloadTask(with: llmDownloadUrl) + if case let .directory(url) = modelConfiguration.id { + Task { @MainActor in + self.state = .downloaded(storageUrl: url) + } + return + } - downloadTask?.resume() + downloadTask?.cancel() + downloadTask = Task(priority: .userInitiated) { + do { + _ = try await loadModelContainer(configuration: modelConfiguration) { progress in + Task { @MainActor in + self.state = .downloading(progress: progress) + } + } + + Task { @MainActor in + self.state = .downloaded(storageUrl: modelConfiguration.modelDirectory()) + } + } catch { + Task { @MainActor in + self.state = .error( + AnyLocalizedError( + error: error, + defaultErrorDescription: LocalizedStringResource("LLM_DOWNLOAD_FAILED_ERROR", bundle: .atURL(from: .module)) + ) + ) + } + } + } } /// Cancels the download of a specified model via a `URLSessionDownloadTask`. diff --git a/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManagerDelegate.swift b/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManagerDelegate.swift deleted file mode 100644 index 780dc05..0000000 --- a/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManagerDelegate.swift +++ /dev/null @@ -1,86 +0,0 @@ -// -// This source file is part of the Stanford Spezi open source project -// -// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md) -// -// SPDX-License-Identifier: MIT -// - -import Foundation -import os -import SpeziViews - - -/// Delegate of the ``LLMLocalDownloadManager`` implementing the methods of the`URLSessionDownloadDelegate` conformance. -class LLMLocalDownloadManagerDelegate: NSObject, URLSessionDownloadDelegate { - /// A Swift `Logger` that logs important information from the `LocalLLMDownloadManager`. - private static let logger = Logger(subsystem: "edu.stanford.spezi", category: "SpeziLLM") - /// A `weak` reference to the ``LLMLocalDownloadManager``. - private weak var manager: LLMLocalDownloadManager? - /// The storage location `URL` of the downloaded LLM. 
- private let storageUrl: URL - - - /// Creates a new `LLMLocalDownloadManagerDelegate` - /// - Parameters: - /// - manager: The ``LLMLocalDownloadManager`` from which the `LLMLocalDownloadManagerDelegate` is initialized. - /// - storageUrl: The `URL` where the downloaded LLM should be stored. - init(manager: LLMLocalDownloadManager, storageUrl: URL) { - self.manager = manager - self.storageUrl = storageUrl - } - - - /// Indicates the progress of the current model download. - func urlSession( - _ session: URLSession, - downloadTask: URLSessionDownloadTask, - didWriteData bytesWritten: Int64, - totalBytesWritten: Int64, - totalBytesExpectedToWrite: Int64 - ) { - let progress = Double(totalBytesWritten) / Double(totalBytesExpectedToWrite) * 100 - Task { @MainActor in - self.manager?.state = .downloading(progress: progress) - } - } - - /// Indicates the completion of the model download including the downloaded file `URL`. - func urlSession(_ session: URLSession, downloadTask: URLSessionDownloadTask, didFinishDownloadingTo location: URL) { - do { - _ = try FileManager.default.replaceItemAt(self.storageUrl, withItemAt: location) - Task { @MainActor in - self.manager?.state = .downloaded(storageUrl: self.storageUrl) - } - } catch { - Task { @MainActor in - self.manager?.state = .error( - AnyLocalizedError( - error: error, - defaultErrorDescription: - LocalizedStringResource("LLM_DOWNLOAD_FAILED_ERROR", bundle: .atURL(from: .module)) - ) - ) - } - Self.logger.error("\(String(describing: error))") - } - } - - /// Indicates an error during the model download - func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { - // The `error` property is set for client-side errors (e.g. couldn't resolve host name), - // the `task.error` property is set in the case of server-side errors. - // If none of these properties are set, no error has occurred. - if let error = error ?? task.error { - Task { @MainActor in - self.manager?.state = .error( - AnyLocalizedError( - error: error, - defaultErrorDescription: LocalizedStringResource("LLM_DOWNLOAD_FAILED_ERROR", bundle: .atURL(from: .module)) - ) - ) - } - Self.logger.error("\(String(describing: error))") - } - } -} diff --git a/Sources/SpeziLLMLocalDownload/LLMLocalDownloadView.swift b/Sources/SpeziLLMLocalDownload/LLMLocalDownloadView.swift index 2a2a091..f1eec99 100644 --- a/Sources/SpeziLLMLocalDownload/LLMLocalDownloadView.swift +++ b/Sources/SpeziLLMLocalDownload/LLMLocalDownloadView.swift @@ -1,11 +1,13 @@ // // This source file is part of the Stanford Spezi open source project // -// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md) +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) // // SPDX-License-Identifier: MIT // +import MLXLLM +import SpeziLLMLocal import SpeziOnboarding import SpeziViews import SwiftUI @@ -53,7 +55,6 @@ public struct LLMLocalDownloadView: View { private let action: () async throws -> Void /// Description of the to-be-downloaded model shown in the ``LLMLocalDownloadView``. private let downloadDescription: Text - /// Indicates the state of the view, get's derived from the ``LLMLocalDownloadManager/state``. 
@State private var viewState: ViewState = .idle @@ -138,7 +139,7 @@ public struct LLMLocalDownloadView: View { .progressViewStyle(LinearProgressViewStyle()) .padding() - Text("Downloaded \(String(format: "%.2f", downloadProgress))% of 100%.", bundle: .module) + Text("Downloaded \(String(format: "%.0f", downloadProgress))% of 100%.", bundle: .module) .padding(.top, 5) } } @@ -155,7 +156,7 @@ public struct LLMLocalDownloadView: View { /// Represents the download progress of the model in percent (from 0 to 100) @MainActor private var downloadProgress: Double { if case .downloading(let progress) = self.downloadManager.state { - return progress + return progress.fractionCompleted * 100 } else if case .downloaded = self.downloadManager.state { return 100.0 } @@ -165,9 +166,7 @@ public struct LLMLocalDownloadView: View { /// A `Bool` flag indicating if the model already exists on the device private var modelExists: Bool { - FileManager.default.fileExists( - atPath: self.downloadManager.llmStorageUrl.path() - ) + self.downloadManager.modelExists } @@ -179,16 +178,12 @@ public struct LLMLocalDownloadView: View { /// - llmDownloadLocation: The local `URL` where the LLM file should be stored. /// - action: The action that should be performed when pressing the primary button of the view. public init( + model: LLMLocalModel, downloadDescription: LocalizedStringResource, - llmDownloadUrl: URL = LLMLocalDownloadManager.LLMUrlDefaults.llama2ChatModelUrl, - llmStorageUrl: URL = .cachesDirectory.appending(path: "llm.gguf"), action: @escaping () async throws -> Void ) { self._downloadManager = State( - wrappedValue: LLMLocalDownloadManager( - llmDownloadUrl: llmDownloadUrl, - llmStorageUrl: llmStorageUrl - ) + wrappedValue: LLMLocalDownloadManager(model: model) ) self.downloadDescription = Text(downloadDescription) self.action = action @@ -203,16 +198,12 @@ public struct LLMLocalDownloadView: View { /// - action: The action that should be performed when pressing the primary button of the view. 
@_disfavoredOverload public init( + model: LLMLocalModel, downloadDescription: S, - llmDownloadUrl: URL = LLMLocalDownloadManager.LLMUrlDefaults.llama2ChatModelUrl, - llmStorageUrl: URL = .cachesDirectory.appending(path: "llm.gguf"), action: @escaping () async throws -> Void ) { self._downloadManager = State( - wrappedValue: LLMLocalDownloadManager( - llmDownloadUrl: llmDownloadUrl, - llmStorageUrl: llmStorageUrl - ) + wrappedValue: LLMLocalDownloadManager(model: model) ) self.downloadDescription = Text(verbatim: String(downloadDescription)) self.action = action @@ -223,6 +214,7 @@ public struct LLMLocalDownloadView: View { #if DEBUG #Preview { LLMLocalDownloadView( + model: .phi3_4bit, downloadDescription: "LLM_DOWNLOAD_DESCRIPTION".localized(.module), action: {} ) diff --git a/Tests/UITests/TestApp/LLMLocal/LLMLocalChatTestView.swift b/Tests/UITests/TestApp/LLMLocal/LLMLocalChatTestView.swift index 34b5797..ad6999c 100644 --- a/Tests/UITests/TestApp/LLMLocal/LLMLocalChatTestView.swift +++ b/Tests/UITests/TestApp/LLMLocal/LLMLocalChatTestView.swift @@ -25,18 +25,16 @@ struct LLMLocalChatTestView: View { } else { LLMChatViewSchema( with: LLMLocalSchema( - modelPath: .cachesDirectory.appending(path: "llm.gguf"), - parameters: .init(maxOutputLength: 512), - contextParameters: .init(contextWindowSize: 1024), + model: .llama3_8B_4bit, formatChat: LLMLocalSchema.PromptFormattingDefaults.llama3 ) ) } } - .navigationTitle("LLM_LOCAL_CHAT_VIEW_TITLE") + .navigationTitle("LLM_LOCAL_CHAT_VIEW_TITLE") } - - + + init(mockMode: Bool = false) { self.mockMode = mockMode } @@ -48,10 +46,10 @@ struct LLMLocalChatTestView: View { NavigationStack { LLMLocalChatTestView(mockMode: true) } - .previewWith { - LLMRunner { - LLMMockPlatform() - } + .previewWith { + LLMRunner { + LLMMockPlatform() } + } } #endif diff --git a/Tests/UITests/TestApp/LLMLocal/Onboarding/LLMLocalOnboardingDownloadView.swift b/Tests/UITests/TestApp/LLMLocal/Onboarding/LLMLocalOnboardingDownloadView.swift index dfbcb4d..37bbe71 100644 --- a/Tests/UITests/TestApp/LLMLocal/Onboarding/LLMLocalOnboardingDownloadView.swift +++ b/Tests/UITests/TestApp/LLMLocal/Onboarding/LLMLocalOnboardingDownloadView.swift @@ -19,11 +19,10 @@ struct LLMLocalOnboardingDownloadView: View { var body: some View { LLMLocalDownloadView( + model: .phi3_4bit, downloadDescription: "LLM_DOWNLOAD_DESCRIPTION", - llmDownloadUrl: LLMLocalDownloadManager.LLMUrlDefaults.llama3InstructModelUrl /// By default, download the Llama3 model - ) { - onboardingNavigationPath.nextStep() - } + action: onboardingNavigationPath.nextStep + ) } } diff --git a/Tests/UITests/TestApp/LLMLocal/Onboarding/LLMLocalOnboardingFlow.swift b/Tests/UITests/TestApp/LLMLocal/Onboarding/LLMLocalOnboardingFlow.swift index f0a52ac..104392e 100644 --- a/Tests/UITests/TestApp/LLMLocal/Onboarding/LLMLocalOnboardingFlow.swift +++ b/Tests/UITests/TestApp/LLMLocal/Onboarding/LLMLocalOnboardingFlow.swift @@ -23,7 +23,7 @@ struct LLMLocalOnboardingFlow: View { LLMLocalOnboardingDownloadView() } } - .interactiveDismissDisabled(!completedOnboardingFlow) + .interactiveDismissDisabled(!completedOnboardingFlow) } } diff --git a/Tests/UITests/TestApp/TestAppDelegate.swift b/Tests/UITests/TestApp/TestAppDelegate.swift index 2f4d84a..f4cc636 100644 --- a/Tests/UITests/TestApp/TestAppDelegate.swift +++ b/Tests/UITests/TestApp/TestAppDelegate.swift @@ -46,7 +46,6 @@ class TestAppDelegate: SpeziAppDelegate { LLMRunner { LLMMockPlatform() - LLMLocalPlatform() // No CA certificate (meaning no encrypted traffic) for development 
purposes, see `caCertificateUrl` above LLMFogPlatform(configuration: .init(host: "spezillmfog.local", caCertificate: nil)) LLMOpenAIPlatform() diff --git a/Tests/UITests/TestAppUITests/TestAppLLMLocalUITests.swift b/Tests/UITests/TestAppUITests/TestAppLLMLocalUITests.swift index d089935..09a0084 100644 --- a/Tests/UITests/TestAppUITests/TestAppLLMLocalUITests.swift +++ b/Tests/UITests/TestAppUITests/TestAppLLMLocalUITests.swift @@ -43,8 +43,22 @@ class TestAppLLMLocalUITests: XCTestCase { sleep(1) // Chat + let inputTextfield = app.textViews["Message Input Textfield"] + XCTAssertTrue(inputTextfield.exists) + + #if !os(macOS) - try app.textViews["Message Input Textfield"].enter(value: "New Message!", options: [.disableKeyboardDismiss]) + if UIDevice.current.userInterfaceIdiom == .pad { + #if RELEASE + throw XCTSkip("Skipped on iPad, see: https://github.com/StanfordBDHG/XCTestExtensions/issues/27") + #endif + + inputTextfield.tap() + sleep(1) + inputTextfield.typeText("New Message!") + } else { + try inputTextfield.enter(value: "New Message!", options: [.disableKeyboardDismiss]) + } #else try app.textFields["Message Input Textfield"].enter(value: "New Message!", options: [.disableKeyboardDismiss]) #endif diff --git a/Tests/UITests/TestAppUITests/TestAppLLMOpenAIUITests.swift b/Tests/UITests/TestAppUITests/TestAppLLMOpenAIUITests.swift index 134ad9e..8e08625 100644 --- a/Tests/UITests/TestAppUITests/TestAppLLMOpenAIUITests.swift +++ b/Tests/UITests/TestAppUITests/TestAppLLMOpenAIUITests.swift @@ -29,6 +29,10 @@ class TestAppLLMOpenAIUITests: XCTestCase { func testSpeziLLMOpenAIOnboarding() throws { // swiftlint:disable:this function_body_length let app = XCUIApplication() + if UIDevice.current.userInterfaceIdiom == .pad { + throw XCTSkip("Skipped on iPad, see: https://github.com/StanfordBDHG/XCTestExtensions/issues/27") + } + XCTAssert(app.buttons["LLMOpenAI"].waitForExistence(timeout: 2)) app.buttons["LLMOpenAI"].tap() @@ -141,6 +145,10 @@ class TestAppLLMOpenAIUITests: XCTestCase { func testSpeziLLMOpenAIChat() throws { let app = XCUIApplication() + if UIDevice.current.userInterfaceIdiom == .pad { + throw XCTSkip("Skipped on iPad, see: https://github.com/StanfordBDHG/XCTestExtensions/issues/27") + } + XCTAssert(app.buttons["LLMOpenAI"].waitForExistence(timeout: 2)) app.buttons["LLMOpenAI"].tap() diff --git a/Tests/UITests/UITests.xcodeproj/project.pbxproj b/Tests/UITests/UITests.xcodeproj/project.pbxproj index e271b19..020fccc 100644 --- a/Tests/UITests/UITests.xcodeproj/project.pbxproj +++ b/Tests/UITests/UITests.xcodeproj/project.pbxproj @@ -605,7 +605,6 @@ SUPPORTS_MAC_DESIGNED_FOR_IPHONE_IPAD = NO; SUPPORTS_XR_DESIGNED_FOR_IPHONE_IPAD = NO; SWIFT_EMIT_LOC_STRINGS = YES; - SWIFT_OBJC_INTEROP_MODE = objcxx; SWIFT_STRICT_CONCURRENCY = complete; SWIFT_VERSION = 5.0; TARGETED_DEVICE_FAMILY = "1,2,7"; @@ -648,7 +647,6 @@ SUPPORTS_MAC_DESIGNED_FOR_IPHONE_IPAD = NO; SUPPORTS_XR_DESIGNED_FOR_IPHONE_IPAD = NO; SWIFT_EMIT_LOC_STRINGS = YES; - SWIFT_OBJC_INTEROP_MODE = objcxx; SWIFT_STRICT_CONCURRENCY = complete; SWIFT_VERSION = 5.0; TARGETED_DEVICE_FAMILY = "1,2,7";
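
With the changes above, consumers no longer point `LLMLocalDownloadView` at a GGUF download URL and a local storage path; they pass an `LLMLocalModel`, and the view drives the Hugging Face Hub download through `LLMLocalDownloadManager`. A minimal sketch of an onboarding-style step built on the new initializer follows; the wrapper view, its completion closure, and the `.phi3_4bit` model choice are illustrative assumptions, only the `LLMLocalDownloadView(model:downloadDescription:action:)` initializer itself is taken from the diff.

```swift
import SpeziLLMLocal
import SpeziLLMLocalDownload
import SwiftUI

/// Hypothetical onboarding step wrapping the new `LLMLocalDownloadView` initializer.
/// The view name, the completion closure, and the `.phi3_4bit` model are assumptions for illustration.
struct LocalModelDownloadStep: View {
    /// Invoked after the model is on-device and the user taps the primary button.
    let onDownloadComplete: () async throws -> Void

    var body: some View {
        LLMLocalDownloadView(
            model: .phi3_4bit,  // any `LLMLocalModel`; mirrors the model used in the test app
            downloadDescription: "The local model will be downloaded to your device.",
            action: onDownloadComplete
        )
    }
}
```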
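
`LLMLocalDownloadManager` can also be driven directly: its `DownloadState.downloading` case now carries a Foundation `Progress` value, and `modelExists` reports whether the `.safetensors` files are already present in the local Hub repository. The sketch below is an assumption-based illustration; the view name and model choice are hypothetical, and the four-state switch reflects only the `DownloadState` cases visible in the diff.

```swift
import SpeziLLMLocal
import SpeziLLMLocalDownload
import SwiftUI

/// Hypothetical view driving the `LLMLocalDownloadManager` directly instead of via `LLMLocalDownloadView`.
struct ModelDownloadProgressView: View {
    @State private var downloadManager = LLMLocalDownloadManager(model: .phi3_4bit)  // model choice is an assumption

    var body: some View {
        VStack {
            if downloadManager.modelExists {
                // The `.safetensors` files are already in the local Hub repository; no download needed.
                Text("Model already available on-device.")
            }

            switch downloadManager.state {
            case .idle:
                Button("Download model") {
                    downloadManager.startDownload()
                }
            case .downloading(let progress):
                // `DownloadState.downloading` now carries a Foundation `Progress` value.
                ProgressView(value: progress.fractionCompleted)
            case .downloaded:
                Text("Download finished.")
            case .error(let error):
                Text(error.localizedDescription)
            }
        }
    }
}
```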
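
On the inference side, `LLMLocalSession` now exposes a public `setup()` for pre-loading the MLX model container and a `generate()` that returns an `AsyncThrowingStream` of generated text. Assuming a session has already been obtained (for example via the `LLMRunner`) and its `context` already holds the user prompt, pre-warming and streaming could look like the following hedged sketch; the helper function name is hypothetical.

```swift
import SpeziLLMLocal

/// Hypothetical helper: pre-warms the local MLX model and collects the streamed output.
/// Assumes `session` was obtained elsewhere (e.g., via the `LLMRunner`) and that its
/// `context` already contains the user prompt.
func prewarmAndStream(session: LLMLocalSession) async throws -> String {
    // Load the model container ahead of the first prompt to reduce initial response latency.
    try await session.setup()

    var output = ""
    let stream = try await session.generate()
    for try await piece in stream {
        output += piece
    }
    return output
}
```

Calling `setup()` ahead of time is optional; per the diff, `generate()` performs the same initialization lazily whenever the session is still `.uninitialized`.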