-
Notifications
You must be signed in to change notification settings - Fork 239
build bitnet from HF bf16 model #1421
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Liqun Fu <[email protected]>
Signed-off-by: Liqun Fu <[email protected]>
| # Make MatMul node (output projection weight node) | ||
| o_proj = 'o_proj' if hasattr(attention, 'o_proj') else 'dense' | ||
| o_matmul_basename = f"/model/layers.{layer_id}/attn/o_proj/MatMul" | ||
| o_weight = eval(f"attention.{o_proj}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoid eval: it is a code smell
| o_weight = eval(f"attention.{o_proj}") | |
| o_weight = getattr(attention, o_proj) |
| o_bias_exists = eval(f"attention.{o_proj}.bias") is not None | ||
| if o_bias_exists: | ||
| o_add_name = f"/model/layers.{layer_id}/attn/o_proj/Add" | ||
| o_bias = eval(f"attention.{o_proj}.bias.detach().numpy()") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: here as well
| cos_cache=cos_cache_name, sin_cache=sin_cache_name, **kwargs, | ||
| ) | ||
|
|
||
| # add an extra SimplifiedLayerNorm before the output projection for attention |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be simplified by adding a setting in the attention_attrs for the RMSNorm before the output projection MatMul.
onnxruntime-genai/src/python/py/models/builder.py
Lines 1619 to 1621 in 36cd2ca
| # Make Q/K SimplifiedLayerNorm nodes | |
| if self.attention_attrs["q_norm"] and self.attention_attrs["k_norm"]: | |
| self.make_qk_norm(layer_id, attention) |
For example:
"q_norm": False, # LayerNorm after MatMul in Q path
"k_norm": False, # LayerNorm after MatMul in K path
"o_norm": False, # LayerNorm before MatMul in output pathThen we can set o_norm = True in the BitNetModel class constructor and insert the following logic here.
# Make SimplifiedLayerNorm node before output MatMul
if self.attention_attrs["o_norm"]:
self.make_o_norm(layer_id, attention) Once that's done, we can remove this code to override the make_attention method.
| """ | ||
| if model_type == "ChatGLMModel": | ||
| if model_type == "BitNetForCausalLM": | ||
| model = GGUFModel(input_path, head_size, hidden_size, intermediate_size, num_attn_heads, num_kv_heads, vocab_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does BitNet not require any post-processing (e.g. undo_permute, swap_norm_types, swap_mlp_types) to match the PyTorch model's class attributes?
| super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) | ||
| self.rms_norm_eps = config.rms_norm_eps | ||
|
|
||
| def make_mlp_proj(self, layer_id, mlp, root_input): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BitNet uses three MatMuls (gate_proj, up_proj, and down_proj) but this method only creates two of them (up_proj and down_proj). You would have to override the base make_mlp_proj method. This version of make_mlp_proj is specific to the Nemotron model because Nemotron does not have a gate projection MatMul.
|
|
||
| act_fn_name = self.make_activation(layer_id, root_input=f"{up_name}/output_0") | ||
|
|
||
| # add an extra SimplifiedLayerNorm after the MLP activation before the down projection MatMul |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the make_attention method, it would be easier to add a setting in mlp_attrs for the extra RMSNorm before the down projection MatMul.
| } | ||
|
|
||
| std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, std::unique_ptr<Config> config) { | ||
| std::set<std::string> llm_types = {"chatglm", "decoder", "gemma", "gemma2", "gemma3_text", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add bitnet in llm_types so that the alphabetical order is maintained.
| elif model_type == "gemma3_text": | ||
| args.chat_template = '<start_of_turn>user\n{system_prompt}{input}<end_of_turn>\n<start_of_turn>model\n' | ||
| elif model_type.startswith("bitnet"): | ||
| # args.chat_template = '{system_prompt}{"role": "user", "content": "{input}"}' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also add the chat template in the model-chat.py example as well?
| elif model_type.startswith("llama"): | ||
| system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>" | ||
| print("Using System Prompt for LLAMA 3, if you are using LLAMA 2 please pass the argument --system_prompt '<s>[INST] <<SYS>>\\n{args.system_prompt}\\n<</SYS>>')") | ||
| elif model_type.startswith("bitnet"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also add the system prompt in the model-chat.py example as well?
|
The PR to add bfloat16 support in the model builder has been opened here. Once it is merged, you can target your PR to merge with the main branch instead. |
Signed-off-by: Liqun Fu <[email protected]>
|
closing this. please reopen if still relevant. |
No description provided.