[Model] Add support for GPTJ architecture #3012

Open · wants to merge 12 commits into base: main
Conversation

tlopex (Contributor) commented Nov 4, 2024

This PR adds support for the GPT-J architecture.

A chat demonstration with the converted model is shown below:

tlopex@tlopex-OMEN-by-HP-Laptop-17-ck1xxx:~/mlc-llm$ mlc_llm chat dist/gpt-j-6b-q4f16_1-MLC --device "cuda:0" --overrides context_window_size=2048 --model ./dist/libs/gpt-j-6b-q4f16_1-cuda.so
[2024-11-04 21:35:57] INFO auto_device.py:79: Found device: cuda:0
[2024-11-04 21:35:57] INFO engine_base.py:143: Using library model: ./dist/libs/gpt-j-6b-q4f16_1-cuda.so
[21:35:58] /home/tlopex/mlc-llm/cpp/serve/config.cc:688: Under mode "local", max batch size will be set to 4, max KV cache token capacity will be set to 2048, prefill chunk size will be set to 2048. 
[21:35:58] /home/tlopex/mlc-llm/cpp/serve/config.cc:688: Under mode "interactive", max batch size will be set to 1, max KV cache token capacity will be set to 2048, prefill chunk size will be set to 2048. 
[21:35:58] /home/tlopex/mlc-llm/cpp/serve/config.cc:688: Under mode "server", max batch size will be set to 128, max KV cache token capacity will be set to 20800, prefill chunk size will be set to 2048. 
[21:35:58] /home/tlopex/mlc-llm/cpp/serve/config.cc:769: The actual engine mode is "interactive". So max batch size is 1, max KV cache token capacity is 2048, prefill chunk size is 2048.
[21:35:58] /home/tlopex/mlc-llm/cpp/serve/config.cc:774: Estimated total single GPU memory usage: 5395.686 MB (Parameters: 3247.127 MB. KVCache: 1008.268 MB. Temporary buffer: 1140.291 MB). The actual usage might be slightly larger than the estimated number.
You can use the following special commands:
  /help               print the special commands
  /exit               quit the cli
  /stats              print out stats of last request (token/sec)
  /metrics            print out full engine metrics
  /reset              restart a fresh chat
  /set [overrides]    override settings in the generation config. For example,
                      `/set temperature=0.5;top_p=0.8;seed=23;max_tokens=100;stop=str1,str2`
                      Note: Separate stop words in the `stop` option with commas (,).
  Multi-line input: Use escape+enter to start a new line.

>>> hi
How may I help you?
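
For reference, the same smoke test can be driven through the MLC LLM Python API instead of the CLI. This is only a minimal sketch reusing the paths from the command above; the `model_lib` and `mode` keyword arguments are assumptions about the `MLCEngine` constructor rather than anything exercised in this PR:

```python
from mlc_llm import MLCEngine

# Paths taken from the CLI invocation above.
model = "dist/gpt-j-6b-q4f16_1-MLC"
model_lib = "./dist/libs/gpt-j-6b-q4f16_1-cuda.so"

# Assumed keywords: model_lib points at the prebuilt library,
# mode="interactive" mirrors the engine mode picked by the chat CLI.
engine = MLCEngine(model, model_lib=model_lib, mode="interactive")

# OpenAI-style chat completion, streamed chunk by chunk.
for response in engine.chat.completions.create(
    messages=[{"role": "user", "content": "hi"}],
    model=model,
    stream=True,
):
    for choice in response.choices:
        print(choice.delta.content or "", end="", flush=True)
print()

engine.terminate()
```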

I found that I had to change the position_embedding code in Relax to run this locally. I wonder if I still need an update there.
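
For reviewers of that Relax change, the relevant difference is the rotation pairing: GPT-J rotates adjacent dimension pairs (2i, 2i+1) and only the first `rotary_dim` dimensions of each head (rotary_dim=64 for gpt-j-6b), whereas the existing kernels follow the GPT-NeoX convention of pairing dimension i with i + rotary_dim/2. Below is a minimal NumPy reference of the GPT-J convention, intended only as a sketch to compare the kernel against, not as the kernel itself:

```python
import numpy as np

def gptj_rope_reference(x: np.ndarray, positions: np.ndarray,
                        rotary_dim: int = 64, theta: float = 10000.0) -> np.ndarray:
    """GPT-J-style rotary embedding on a (seq_len, num_heads, head_dim) tensor.

    Only the first `rotary_dim` dims of each head are rotated, and the
    rotation pairs adjacent elements (2i, 2i + 1) rather than (i, i + half).
    """
    inv_freq = 1.0 / (theta ** (np.arange(0, rotary_dim, 2) / rotary_dim))   # (rotary_dim // 2,)
    angles = positions[:, None].astype(np.float64) * inv_freq[None, :]       # (seq, rotary_dim // 2)
    cos = np.cos(angles)[:, None, :]                                         # (seq, 1, rotary_dim // 2)
    sin = np.sin(angles)[:, None, :]

    rot, rest = x[..., :rotary_dim], x[..., rotary_dim:]
    x_even, x_odd = rot[..., 0::2], rot[..., 1::2]                           # adjacent pairs
    rot_even = x_even * cos - x_odd * sin
    rot_odd = x_odd * cos + x_even * sin
    rotated = np.stack([rot_even, rot_odd], axis=-1).reshape(rot.shape)      # re-interleave
    return np.concatenate([rotated, rest], axis=-1)
```

This follows the interleaved ("rotate every two") convention used by the HuggingFace GPT-J implementation, so it can serve as a quick sanity check against `transformers` outputs.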

MasterJH5574 (Member) commented

@tlopex Thanks! Do you mind fixing the lint errors as shown in CI?

tlopex (Contributor, Author) commented Nov 5, 2024

@MasterJH5574 Sorry for the delay. I thought I had solved the lint issue yesterday.
Now there seems to be something wrong with model compilation:

[2024-11-05 10:28:28] INFO compile.py:185: Registering metadata: {'model_type': 'gptj', 'quantization': 'q4f32_1', 'context_window_size': 2048, 'sliding_window_size': -1, 'attention_sink_size': -1, 'prefill_chunk_size': 2048, 'tensor_parallel_shards': 1, 'pipeline_parallel_stages': 1, 'kv_state_kind': 'kv_cache', 'max_batch_size': 1}
error: Unsupported RoPE scaling type: gptj
 --> /Users/catalyst/Workspace/miniforge3/envs/mlc-llm-ci/lib/python3.8/site-packages/tvm/relax/frontend/nn/llm/kv_cache.py:708:53
     |
 708 |                                                      _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, rope_scaling),
     |                                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Compiling with arguments:
  --config          GPTJConfig(vocab_size=50400, n_embd=4096, n_layer=28, n_head=16, layer_norm_epsilon=1e-05, rotary_dim=64, activation_function='gelu_new', n_inner=None, rope_scaling={'rope_type': 'gptj'}, context_window_size=2048, prefill_chunk_size=2048, tensor_parallel_shards=1, max_batch_size=1, head_dim=0, kwargs={})
  --quantization    GroupQuantize(name='q4f32_1', kind='group-quant', group_size=32, quantize_dtype='int4', storage_dtype='uint32', model_dtype='float32', linear_weight_layout='NK', quantize_embedding=True, quantize_final_fc=True, num_elem_per_storage=8, num_storage_per_group=4, max_int_value=7, tensor_parallel_shards=0)
  --model-type      gptj
  --target          {"thread_warp_size": runtime.BoxInt(32), "host": {"mtriple": "arm64-apple-darwin22.1.0", "tag": "", "kind": "llvm", "mcpu": "apple-m1", "keys": ["arm_cpu", "cpu"]}, "max_threads_per_block": runtime.BoxInt(1024), "max_function_args": runtime.BoxInt(31), "max_num_threads": runtime.BoxInt(256), "kind": "metal", "max_shared_memory_per_block": runtime.BoxInt(32768), "tag": "", "keys": ["metal", "gpu"]}
  --opt             flashinfer=0;cublas_gemm=0;faster_transformer=0;cudagraph=0;cutlass=0;ipc_allreduce_strategy=NONE
  --system-lib-prefix ""
  --output          /var/folders/n1/5d_r6z251v39vwpj8hj_z1vc0000gp/T/tmpl4pq_51h/lib328.dylib
  --overrides       context_window_size=None;sliding_window_size=None;prefill_chunk_size=None;attention_sink_size=None;max_batch_size=None;tensor_parallel_shards=1;pipeline_parallel_stages=None
note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
[10:28:28] /Users/catalyst/Workspace/mlc-ai-package-self-runner/_work/package/package/tvm/src/relax/ir/block_builder.cc:65: Warning: BlockBuilder destroyed with remaining blocks!

It is the same problem I ran into on my own device before I updated position_embedding in TVM, so I think I may need to open a pull request there.
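
For context on the failure itself: the RoPE helper in `kv_cache.py` only recognizes its built-in `rope_scaling` types, so the `{'rope_type': 'gptj'}` entry in the config falls through to the "Unsupported RoPE scaling type" error. Purely as an illustration of the kind of branch the TVM side would need (the function and argument names below are hypothetical, not the actual `kv_cache.py` API), the dispatch looks roughly like this:

```python
from typing import Callable, Dict, Optional

def rope_angle_default(pos: int, dim_idx: int, rotary_dim: int, theta: float) -> float:
    """NeoX-style pairing: dim i rotates together with dim i + rotary_dim // 2."""
    return pos / theta ** (2 * (dim_idx % (rotary_dim // 2)) / rotary_dim)

def rope_angle_gptj(pos: int, dim_idx: int, rotary_dim: int, theta: float) -> float:
    """GPT-J pairing: adjacent dims (2i, 2i + 1) share one rotation frequency."""
    return pos / theta ** (2 * (dim_idx // 2) / rotary_dim)

def select_rope_angle(rope_scaling: Optional[Dict]) -> Callable[[int, int, int, float], float]:
    """Pick the per-dimension rotation angle from the rope_scaling config entry."""
    rope_type = (rope_scaling or {}).get("rope_type", "default")
    if rope_type == "default":
        return rope_angle_default
    if rope_type == "gptj":
        return rope_angle_gptj
    # Unknown types end up here, which is the path the compile log above hits.
    raise ValueError(f"Unsupported RoPE scaling type: {rope_type}")
```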
