Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AICI to use WebAssembly Component Model #84

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

AaronFriel
Copy link

@AaronFriel AaronFriel commented Mar 30, 2024

This is a significant change to the AICI Runtime (host) and AICI Controller (guest) to use WASI components. As part of this change, a significant amount of unsafe code is removed and the protocol is simplified to remove the need for "BLOB" types and side-channels.

The protocol is documented in wit/controller.wit, and a minimal WASI runtime is provided to AI CI controllers.

Some notes:

  • The AICI runtime no longer directly reads and writes to the guest's memory. Instead, the guest provides a Runner resource (using WebAssembly Component terminology), which exposes the low-level protocol to the host as a constructor and trait with methods.
  • The Blob protocols are removed entirely, replaced by the Runner resource. This and other side-channels for communicating with the runtime, e.g. allowed tokens (logit biases) outside of MidProcessResult, are removed.
  • The (Variable) Storage and Tokenizer protocols are separate WebAssembly Components, which can be versioned independently of the runtime.
  • Types are changed to be consistent with the WebAssembly interface, e.g.: SeqId is used in far more places to avoid casts.

@AaronFriel
Copy link
Author

AaronFriel commented Mar 30, 2024

@mmoskal There are still some a few TODOs here I believe:

  • Logging (stdout/stderr?) from the controllers
  • Handling errors, panics
  • Update: No change vs upstream. Clock - I see a SPECTRE/MELTDOWN mitigation in the upstream main, not sure if it will be straightforward to override when using the WASI runtime's builtins.
  • Update: fixed. Build errors after rebasing.

@AaronFriel
Copy link
Author

PR updated with some refactors, among which is an ergonomic improvement to exporting guests that fixes running cargo build from the workspace root.

It seemed to me that until the WASI Preview 2 target fully lands, the controllers may need to be built as libraries with type cdylib, though I couldn't find anything definitive. Between that and some of the machinery used by export!(), compiling those crates for, e.g., linux-x86-64, would error with cc.

The improved export macro hides the machinery:

#[macro_export]
macro_rules! export {
($ty:ident) => {
#[doc(hidden)]
#[cfg(target_arch = "wasm32")]
$crate::bindings::export!($ty with_types_in $crate::bindings);
};
}

@squillace
Copy link

Hi @AaronFriel, I LOOOOOVVVVEEEEE this. My team does a bunch of the infrastructure work upstream supporting wasm components, and I'd like to see how to help bring this in to the project. 🖖

@emrekiciman
Copy link
Collaborator

Thanks very much for this PR, @AaronFriel ! And thanks @squillace for helping review!

@squillace
Copy link

@AaronFriel don't fret, we'll get here. You submitted when we had KubeCon followed by easter followed by the heavens being swallowed by the moon. People are returning end of this week....

@AaronFriel
Copy link
Author

@squillace I'm in no rush, and pleased to see your review when you're able!

Sorry if I did anything to nag you - I don't think I triggered anything on my end since posting the PR?

@squillace
Copy link

nope, just don't like not communicating n prs when someone is trying to help do the right thing.

@mmoskal
Copy link
Member

mmoskal commented Apr 10, 2024

This looks great, from my non-very-well-informed POV!

Unfortunately, I'm in the middle of some work items that may affect this. In particular, I'm dropping the pre/post callbacks and only leaving the mid callback. It looks like we would be unable to run the pre/post fast enough, especially with speculative decoding (I have not considered that in the past).

I also want to support native controllers which is probably relevant here.

This may take a few weeks to finish and is quite high priority for us here.

@AaronFriel
Copy link
Author

The 1 token penalty in #68 seems very reasonable for the capabilities offered in AICI. I'm not intimately familiar with the workings of the rLLM implementations, beyond what was necessary for this PR, but from your notes it sounds like blocking the LLM holds up an entire batch, effectively a pipeline stall?

@mmoskal
Copy link
Member

mmoskal commented Apr 10, 2024

In non-speculative implementations, the pre/post happens on the "critical path" of sampling after we got the logits from the GPU but before we started the next step on the GPU (the next step needs the sampled tokens from the current step). Thus, the GPU sits idle while we hold the entire batch.

Now, in principle it would be possible to be working on two sets of sequences and swap them (running pre/post on one, while the other is already computing logits on t he GPU and vice versa). The problem with this is that it needs 2x memory for KV cache which is typically the limiting factor for batch size and thus throughput. It may be possible for the draft model though in case it's too fast even with the new only-mid strategy.

@AaronFriel
Copy link
Author

AaronFriel commented Apr 10, 2024

Does vLLM have a mechanism for adding and removing sequences from batches, or would it be simpler in AICI to effectively Promise.race() the WASM, and if mid_process exceeds the deadline, to treat it as a de-facto fork?

That is, never allow AICI to block the LLM, but you might generate logits you throw away. In those situations, backtrack and resume?

Taking this to #68 though, because it sounds like this PR is blocked on understanding that discussion.

This is a significant change to the AICI Runtime (host) and AICI Controller
(guest) to use WASI components. As part of this change, a significant amount of
unsafe code is removed and the protocol is simplified to remove the need for
"BLOB" types and side-channels.

The protocol is documented in `wit/controller.wit`, and a minimal WASI runtime
is provided to AI CI controllers.

Some notes:
* The AICI runtime now longer directly reads and writes to the guest's memory.
  Instead, the guest provides a `Runner` resource (using WebAssembly Component
  terminology), which exposes the low-level protocol to the host as a
  constructor and trait with methods.
* The Blob protocols are removed entirely, replaced by the `Runner` resource.
  This and other side-channels for communicating with the runtime, e.g. allowed
  tokens (logit biases) outside of `MidProcessResult`, are removed.
* The (Variable) Storage and Tokenizer protocols are separate WebAssembly
  Components, which can be versioned independently of the runtime.
* Types are changed to be consistent with the WebAssembly interface, e.g.:
  `SeqId` is used in far more places to avoid casts.
@AaronFriel
Copy link
Author

@mmoskal I'm most of the way through rebasing on the latest changes, though I think it'd be good for us to chat some time (I'll follow up on our email) about whether this would be acceptable in the near term. Personally, I'm very excited about this because I want to explore alternatives to the current API design(s) for LLMs and more powerful protocols than the request-response pattern modern LLMs have. Sadly, it is a lot of work to rebase on the frequent protocol changes in this repo, and I haven't been able to make much progress.

I would like to propose and discuss a couple things:

  1. Using cfg flags to enable wasip2 and the current wasip1 host/guest component protocol live side-by-side. (Conveniently, we can call this support for AICI modules and support for AICI components.)

  2. Making wit the IDL of choice for designing protocol changes. With this change protocol is richer and much easier to understand than the side-effect driven protocol of calling return_logit_bias, or the various blob and variable storage APIs. I think this work in progress PR for example solves this panic by using a rich type for branches:

    b.map_mask(|vob| {
    if used_logits {
    panic!("aici_mid_process: multiple branches with sampling not yet supported");
    }
    used_logits = true;
    host::return_logit_bias(&vob) as usize
    })

  3. Better understanding where the performance critical portions of the code are. The use of shared memory and unsafe operations gives me pause - which ones are performance critical and which ones are shaving off nanoseconds while where are milliseconds left on the table? E.g.: the critical path for mid_process always involves a serde.

    aici/aicirt/src/main.rs

    Lines 937 to 939 in eec6e52

    Some("mid_process") => Ok(serde_json::to_value(
    &self.aici_mid_process(serde_json::from_value(json)?)?,
    )?),

@squillace
Copy link

@AaronFriel this is amazing work! I wanna tag @yoshuawuyts here for component and rust expertise and @devigned to track the work for usage. I'd LOOOOOOOVVVEEE to try this out.

@AaronFriel
Copy link
Author

AaronFriel commented Sep 9, 2024

OK, updated the PR to check off all of the TODOs.

SPECTRE mitigation via a bounded monotonic clock:

impl HostMonotonicClock for BoundedResolutionClock {
fn resolution(&self) -> u64 {
self.resolution.as_nanos() as u64
}
fn now(&self) -> u64 {
let now = std::time::Instant::now();
let nanos = now.duration_since(self.initial).as_nanos() as u64;
let res = self.resolution.as_nanos() as u64;
let nanos = if res > 0 { nanos / res * res } else { nanos };
nanos as u64
}
}
impl HostWallClock for BoundedResolutionClock {
fn resolution(&self) -> Duration {
self.resolution
}
fn now(&self) -> Duration {
let now = std::time::SystemTime::now();
let nanos = now
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos() as u64;
let res = self.resolution.as_nanos() as u64;
let nanos = if res > 0 { nanos / res * res } else { nanos };
Duration::from_nanos(nanos)
}
}

And logging is wired up again to an in memory output pipe with bounded capacity:

let wall_clock =
BoundedResolutionClock::new(Duration::from_nanos(limits.timer_resolution_ns));
let monotonic_clock = wall_clock.clone();
let log = wasmtime_wasi::pipe::MemoryOutputPipe::new(MAX_LOG);
let stdout = log.clone();
let stderr = log.clone();
ModuleData {
id,
log: log,
printed_log: 0,
globals,
group_channel,
store_limits,
storage_log: Vec::new(),
wasi_ctx: wasmtime_wasi::WasiCtxBuilder::new()
.wall_clock(wall_clock)
.monotonic_clock(monotonic_clock)
.stdout(stdout)
.stderr(stderr)
.build(),
resource_table: wasmtime_wasi::ResourceTable::new(),
}

With these changes, this works as expected:

$ ./server.sh --trace-rt phi2

And

$ ./aici.sh build controllers/pyctrl --tag pyctrl
...
$ ./aici.sh run ./scripts/list-of-five.py --ctrl pyctrl
[0]: FIXED "What‧ are‧ the‧ most‧ popular‧ types‧ of‧ vehicles‧?‧\n"
[0]: FIXED "1‧."
[0]: 
[0]: 
[0]: GEN: " Cars‧\n"
[0]: FIXED "2‧."
[0]: 
[0]: 
[0]: 
[0]: GEN: " B‧uses‧\n"
[0]: FIXED "3‧."
[0]: 
[0]: 
[0]: 
[0]: GEN: " Motor‧cycles‧\n"
[0]: FIXED "4‧."
[0]: 
[0]: 
[0]: 
[0]: GEN: " Tru‧cks‧\n"
[0]: FIXED "5‧."
[0]: 
[0]: 
[0]: 
[0]: 
[0]: GEN: " B‧icy‧cles‧\n"
[0]: FIXED "\n"
[0]: 
[0]: 
[DONE]
[Response]  What are the most popular types of vehicles?
1. Cars
2. Buses
3. Motorcycles
4. Trucks
5. Bicycles



response saved to tmp/response.json
Usage: {'sampled_tokens': 22, 'ff_tokens': 37, 'cost': 81}
Timing: {'http_response': 0.07584023475646973, 'data0': 0.07589316368103027, 'first_token': 0.12601089477539062, 'last_token': 1.248687505722046}
Tokens/sec: {'prompt': 19.93196819860192, 'sampling': 17.618499343659753}
Storage: {'result': '1. Cars\n2. Buses\n3. Motorcycles\n4. Trucks\n5. Bicycles\n\n'}

@AaronFriel
Copy link
Author

Performance, at least for this example, looks to be equal or better per timers log output.

Components branch:

INFO [rllm_llamacpp::llamacpp::tmodel] model forward: step #120 55.31ms; 1 tok(s); 18.1tps
DEBUG [rllm::engine] timers
     56.238ms (x20) step
       0.1%     0.030ms (x20) .schedule
       0.0%     0.011ms (x20) .aici_mid
      99.9%    56.173ms (x20) .run_model
          99.6%    55.974ms (x20) .model_fwd
           0.2%     0.139ms (x20) .sample
              29.0%     0.040ms (x20) .aici_bias
              14.9%     0.030ms (x14) .sample

Main branch:

INFO [rllm_llamacpp::llamacpp::tmodel] model forward: step #60 57.65ms; 1 tok(s); 17.3tps
DEBUG [rllm::engine] timers
     55.933ms (x20) step
       0.1%     0.032ms (x20) .schedule
       0.0%     0.012ms (x20) .aici_mid
      99.9%    55.863ms (x20) .run_model
          99.7%    58.612ms (x19) .model_fwd
           0.2%     0.131ms (x19) .sample
              32.0%     0.042ms (x19) .aici_bias
              13.9%     0.029ms (x12) .sample

With aici_mid within measurement error.

@squillace
Copy link

I really can't wait to try this.

@mmoskal
Copy link
Member

mmoskal commented Sep 10, 2024

This is great! I love the conciseness of the interface description.

I have not had much time to work on AICI lately, focusing on the specific llguidance controller (which is mostly being run natively, but with similar interface).

Just as a general heads up - the problem with run into with AICI in production is the case where there are more sequences in batch (and thus parallel controller processes) than cores. This is because I spin for a while on futexes (to minimize latency), and this kills performance when we're out of cores. This would need to be fixed somehow. The latency minimization was mostly there when we still had post/pre_process(); for mid_process() it shouldn't matter that much.

@AaronFriel
Copy link
Author

I wonder if the streaming protocol of WASI helps here - instead of using a futex using IPC with efficient reading and writing to shared circular buffers?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants