|
481 | 481 | " ):\n",
|
482 | 482 | " super().__init__()\n",
|
483 | 483 | " assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
|
484 |     | - " assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n",
    | 484 | + " assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\" # NEW\n",
485 | 485 | "\n",
|
486 | 486 | " self.d_out = d_out\n",
|
487 | 487 | " self.num_heads = num_heads\n",
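
The new assert above guarantees that the query heads divide evenly into key-value groups. A minimal sketch of the idea behind grouped-query attention (hypothetical shapes, not the notebook's full implementation): each group of `num_heads // num_kv_groups` query heads shares one key/value head, which is expanded with `repeat_interleave` before the attention scores are computed.

```python
import torch

# Hypothetical illustration values; only the divisibility matters here
b, num_tokens = 2, 6
num_heads, num_kv_groups, head_dim = 8, 2, 16
group_size = num_heads // num_kv_groups  # well-defined thanks to the assert

queries = torch.randn(b, num_heads, num_tokens, head_dim)
keys = torch.randn(b, num_kv_groups, num_tokens, head_dim)

# Expand the shared KV heads so each query head has a matching key head
keys_expanded = keys.repeat_interleave(group_size, dim=1)
attn_scores = queries @ keys_expanded.transpose(2, 3)
print(attn_scores.shape)  # torch.Size([2, 8, 6, 6])
```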
|
|
886 | 886 | " \"n_heads\": 32, # Number of attention heads\n",
|
887 | 887 | " \"n_layers\": 32, # Number of layers\n",
|
888 | 888 | " \"hidden_dim\": 11_008, # Size of the intermediate dimension in FeedForward\n",
|
889 |     | - " \"dtype\": torch.bfloat16 # Lower-precision dtype to save memory\n",
    | 889 | + " \"dtype\": torch.bfloat16 # Lower-precision dtype to reduce memory usage\n",
890 | 890 | "}"
|
891 | 891 | ]
|
892 | 892 | },
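
As a quick sanity check on the `dtype` comment above: with roughly 7B parameters (an assumed round number for this config), storing the weights in bfloat16 instead of float32 halves the memory footprint.

```python
import torch

num_params = 7_000_000_000  # rough parameter count, assumed for illustration

for dtype in (torch.float32, torch.bfloat16):
    n_bytes = num_params * torch.tensor([], dtype=dtype).element_size()
    print(f"{dtype}: ~{n_bytes / 1e9:.0f} GB")  # ~28 GB vs. ~14 GB
```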
|
|
909 | 909 | " \"n_kv_groups\": 8, # NEW: Key-Value groups for grouped-query attention\n",
|
910 | 910 | " \"rope_base\": 500_000.0, # NEW: The base in RoPE's \"theta\" was increased to 500_000\n",
|
911 | 911 | " \"rope_freq\": None, # NEW: Additional configuration for adjusting the RoPE frequencies\n",
|
912 |     | - " \"dtype\": torch.bfloat16 # Lower-precision dtype to save memory\n",
    | 912 | + " \"dtype\": torch.bfloat16 # Lower-precision dtype to reduce memory usage\n",
913 | 913 | "}"
|
914 | 914 | ]
|
915 | 915 | },
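
The raised `rope_base` changes how quickly the rotary embeddings rotate across the head dimension. A short sketch of the standard RoPE inverse-frequency formula (not the notebook's exact implementation) shows the effect of moving the base from 10,000 to 500,000:

```python
import torch

head_dim = 64
for base in (10_000.0, 500_000.0):
    # Standard RoPE inverse frequencies: 1 / base**(2i / head_dim)
    inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2).float() / head_dim)
    # A larger base lowers the high-index frequencies (longer wavelengths),
    # which helps resolve positions over longer contexts
    print(f"base={base:>9.0f}  lowest freq: {inv_freq[-1].item():.1e}")
```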
|
|
2062 | 2062 | " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
|
2063 | 2063 | " \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
|
2064 | 2064 | " \"rope_freq\": None, # Additional configuration for adjusting the RoPE frequencies\n",
|
2065 |      | - " \"dtype\": torch.bfloat16 # Lower-precision dtype to save memory\n",
     | 2065 | + " \"dtype\": torch.bfloat16 # Lower-precision dtype to reduce memory usage\n",
2066 | 2066 | "}\n",
|
2067 | 2067 | "\n",
|
2068 | 2068 | "LLAMA31_CONFIG_8B = {\n",
|
|
2074 | 2074 | " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
|
2075 | 2075 | " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
|
2076 | 2076 | " \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
|
2077 |      | - " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
     | 2077 | + " \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n",
2078 | 2078 | " \"rope_freq\": { # NEW: RoPE frequency scaling\n",
|
2079 | 2079 | " \"factor\": 8.0,\n",
|
2080 | 2080 | " \"low_freq_factor\": 1.0,\n",
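
The `rope_freq` entries shown above (`factor`, `low_freq_factor`) drive a frequency-rescaling step; the hunk is cut off here, so the remaining parameters in the sketch below (`high_freq_factor`, `old_context_len`) are assumed values following the commonly published Llama 3.1 defaults. The idea: low frequencies (long wavelengths) are slowed down by `factor`, high frequencies pass through unchanged, and the band in between is linearly interpolated.

```python
import torch

def rescale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                     high_freq_factor=4.0, old_context_len=8192):
    # high_freq_factor and old_context_len are assumed defaults (not in the hunk)
    wavelen = 2 * torch.pi / inv_freq
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    # Linear interpolation for the medium-frequency band
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq

    # Long wavelengths: divide by factor; short wavelengths: keep as-is
    adjusted = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    medium = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    return torch.where(medium, smoothed, adjusted)

inv_freq = 1.0 / 500_000.0 ** (torch.arange(0, 64, 2).float() / 64)
print(rescale_inv_freq(inv_freq)[:4])  # the highest frequencies pass through unchanged
```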
|
|
2448 | 2448 | " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
|
2449 | 2449 | " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
|
2450 | 2450 | " \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
|
2451 |      | - " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
     | 2451 | + " \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n",
2452 | 2452 | " \"rope_freq\": { # NEW: RoPE frequency scaling\n",
|
2453 | 2453 | " \"factor\": 8.0,\n",
|
2454 | 2454 | " \"low_freq_factor\": 1.0,\n",
|
|
2467 | 2467 | " \"hidden_dim\": 8192, # NEW: Almost half the size of the intermediate dimension in FeedForward\n",
|
2468 | 2468 | " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
|
2469 | 2469 | " \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
|
2470 |      | - " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
     | 2470 | + " \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n",
2471 | 2471 | " \"rope_freq\": { # RoPE frequency scaling\n",
|
2472 | 2472 | " \"factor\": 32.0, # NEW: Adjustment of the rescaling factor\n",
|
2473 | 2473 | " \"low_freq_factor\": 1.0,\n",
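
The halved `hidden_dim` in the last hunk is where much of the smaller model's per-block savings come from: a SwiGLU-style FeedForward uses three `emb_dim x hidden_dim` projections, so its parameter count scales linearly with `hidden_dim`. A rough calculation (the `emb_dim` values of 4096 and 2048 are assumed here; they do not appear in the hunks above):

```python
def swiglu_ffn_params(emb_dim, hidden_dim):
    # Gate, up, and down projections, each emb_dim x hidden_dim (no biases)
    return 3 * emb_dim * hidden_dim

# emb_dim values assumed for illustration only
print(f"hidden_dim=14_336, emb_dim=4096: {swiglu_ffn_params(4096, 14_336):,}")  # 176,160,768
print(f"hidden_dim= 8_192, emb_dim=2048: {swiglu_ffn_params(2048, 8_192):,}")   # 50,331,648
```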