Add KTO Loss #475 (Merged)

Commits (30):
- 4471ba6 Add KTO Loss (hebiao064)
- d3f565b Fix Tests (hebiao064)
- ab08ab6 formatting (hebiao064)
- 98aa519 Add more docstrings (hebiao064)
- 6e67869 Fix tests (hebiao064)
- 3a76c76 Add Benchmark Result (hebiao064)
- 7114a19 Merge branch 'main' into kto_loss (hebiao064)
- f46a595 Merge branch 'main' into kto_loss (hebiao064)
- 1992153 Add KL, Unpair flag, Preference Labels (hebiao064)
- 3aebf8d Reorder Preference Labels and Bias (hebiao064)
- 1dd10a7 Fix Tests (hebiao064)
- 130e909 Fix all tests (hebiao064)
- c2fab1d make it random number for chosen (hebiao064)
- fe7eafc Merge branch 'main' into kto_loss (hebiao064)
- eeaa570 Update benchmark (hebiao064)
- 85a582e speed up kto loss and some refactor (#495) (shivam15s)
- 6d50d44 Merge branch 'main' into kto_loss (hebiao064)
- 846dc2e Change sign of loss to align with merged changes (hebiao064)
- 33fa548 Add KL into KTO Test (hebiao064)
- f7b29d5 Add KL and Benchmark (hebiao064)
- 29e818e Fix the speed slow down by removing .item() which would incur gpu-cpu… (hebiao064)
- c478e75 Merge branch 'main' into kto_loss (hebiao064)
- 06b2350 Fix checkstyle (hebiao064)
- 3cf3771 Remove unnecessary change from conflict merge (hebiao064)
- 6d33947 Merge branch 'main' into kto_loss (hebiao064)
- 26f48d0 Merge branch 'main' into kto_loss (hebiao064)
- 71b1773 Fix comments (hebiao064)
- c5e9c36 Merge branch 'main' into kto_loss (hebiao064)
- dc934e2 Merge branch 'main' into kto_loss (lancerts)
- 4b7dfcf Merge branch 'main' into kto_loss (hebiao064)
`.gitignore`:

```diff
@@ -5,6 +5,7 @@ site/
 .venv/
 venv/
 .ipynb_checkpoints/
+.vscode/

 # Misc
 .DS_Store
```
New benchmark README (+30 lines):

## Benchmarking Liger Kernels

Follow these steps to benchmark and visualize kernel performance:

1. Create a benchmark script
   - Add your script under `benchmark/scripts/`
   - Name it according to the kernel (e.g., `benchmark_<kernel_name>.py`)

2. Run the benchmark
   - Results will be saved to `benchmark/data/all_benchmark_data.csv`

   Example: benchmarking KTO loss:

   ```bash
   cd benchmark
   python scripts/benchmark_kto_loss.py
   ```

3. Visualize results
   - Use the visualization script with appropriate parameters

   Example: visualizing KTO loss benchmark results:

   ```bash
   python benchmarks_visualizer.py \
       --kernel-name kto_loss \
       --metric-name memory \
       --kernel-operation-mode full
   ```

4. View results
   - Generated plots will be saved in `benchmark/visualizations/`
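The benchmark runners in this PR report each metric at the 20th, 50th, and 80th percentiles (the `y_20`/`y_50`/`y_80` fields of `SingleBenchmarkRunOutput`). As a rough, stdlib-only illustration of that kind of summarization, the sketch below uses linear interpolation between sorted samples; the `summarize` helper and its interpolation scheme are illustrative assumptions, not the repository's actual `utils`/`QUANTILES` machinery:

```python
# Illustrative only: summarize repeated measurement samples at the
# 20th/50th/80th percentiles, mirroring the y_20/y_50/y_80 fields the
# benchmark scripts report. NOT the liger-kernel utils implementation.

def summarize(samples, quantiles=(0.2, 0.5, 0.8)):
    """Return the requested quantiles of `samples`, using linear
    interpolation between the two nearest sorted values."""
    xs = sorted(samples)
    out = []
    for q in quantiles:
        pos = q * (len(xs) - 1)          # fractional index into sorted data
        lo = int(pos)
        hi = min(lo + 1, len(xs) - 1)
        frac = pos - lo
        out.append(xs[lo] * (1 - frac) + xs[hi] * frac)
    return tuple(out)

# Example: five timing samples in milliseconds
y_20, y_50, y_80 = summarize([1.2, 1.0, 1.5, 1.1, 1.3])
```

In the real scripts this aggregation is delegated to `triton.testing.do_bench(..., quantiles=QUANTILES)`, which computes the quantiles over its timing samples internally.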
`benchmark/scripts/benchmark_kto_loss.py` (new file, +264 lines):

```python
import os
import sys

import torch
import triton
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    _test_memory,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss
from liger_kernel.utils import infer_device

device = infer_device()
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


class TorchKTOLoss(torch.nn.Module):
    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        bias: bool = False,
        ref_bias: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
    ):
        from test.chunked_loss.test_kto_loss import HFKTOLoss

        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=bias, dtype=dtype
        )
        self.ref_lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=ref_bias, dtype=dtype
        )
        self.kto_loss = HFKTOLoss(
            ignore_index=ignore_index, beta=beta, use_ref_model=True
        ).get_batch_loss_metrics

    def forward(self, x, ref_x, y):
        return self.kto_loss(
            self.lin.weight,
            x,
            y,
            self.lin.bias,
            ref_x,
            self.ref_lin.weight,
            self.ref_lin.bias,
        )[0]


class LigerKTOLoss(torch.nn.Module):
    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        bias: bool = False,
        ref_bias: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
    ):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=bias, dtype=dtype
        )
        self.ref_lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=ref_bias, dtype=dtype
        )
        self.kto_loss = LigerFusedLinearKTOLoss(
            ignore_index=ignore_index, beta=beta, use_ref_model=True
        )

    def forward(self, x, ref_x, y):
        return self.kto_loss(
            self.lin.weight,
            x,
            y,
            self.lin.bias,
            ref_x,
            self.ref_lin.weight,
            self.ref_lin.bias,
        )[0]


def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    bias = input.extra_benchmark_config["bias"]
    beta = input.extra_benchmark_config["beta"]
    ignore_index = input.extra_benchmark_config["ignore_index"]
    provider = input.kernel_provider

    torch_kto_loss = TorchKTOLoss(
        H=H,
        V=V,
        dtype=dtype,
        bias=bias,
        ref_bias=bias,
        ignore_index=ignore_index,
        beta=beta,
    ).to(device)

    liger_kto_loss = LigerKTOLoss(
        H=H,
        V=V,
        dtype=dtype,
        bias=bias,
        ref_bias=bias,
        ignore_index=ignore_index,
        beta=beta,
    ).to(device)

    # Input shape: [B, T, H]
    _input = torch.randn(B, T, H, device=device, dtype=dtype)
    # Target shape: [B, T]
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)

    # Add ignore_index tokens to simulate padding
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
    target.view(-1)[indices_to_assign] = ignore_index

    # Add ref_x with the same shape as _input
    ref_input = torch.randn(B, T, H, device=device, dtype=dtype)

    def fwd():
        if provider == "liger":
            return liger_kto_loss(_input, ref_input, target)
        elif provider == "huggingface":
            return torch_kto_loss(_input, ref_input, target)

    def full():
        y = fwd()
        y.backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    bias = input.extra_benchmark_config["bias"]
    beta = input.extra_benchmark_config["beta"]
    ignore_index = input.extra_benchmark_config["ignore_index"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    torch_kto_loss = TorchKTOLoss(
        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
    ).to(device)
    liger_kto_loss = LigerKTOLoss(
        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
    ).to(device)

    # Input shape: [B, T, H]
    _input = torch.randn(B, T, H, device=device, dtype=dtype)

    # Target shape: [B, T]
    target = torch.randint(V, (B, T), device=device, dtype=torch.long)

    # Add ignore_index tokens
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
    target.view(-1)[indices_to_assign] = ignore_index

    # Add ref_x with the same shape as _input
    ref_input = torch.randn(B, T, H, device=device, dtype=dtype)

    def fwd():
        if provider == "liger":
            return liger_kto_loss(_input, ref_input, target)
        elif provider == "huggingface":
            return torch_kto_loss(_input, ref_input, target)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "kto_loss",
        "x_name": "B",
        "x_label": "Batch Size (B)",
        "x_values": [2**i for i in range(1, 6)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "T": 512,
                "H": 1024,
                "V": 128256,
                "mode": "forward",
                "dtype": torch.bfloat16,
                "bias": True,
                "beta": 0.1,
                "ignore_index": 42,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_kto_loss,
        kernel_operation_modes=["forward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs
    )

    run_benchmarks(
        bench_test_fn=bench_memory_kto_loss,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs
    )
```
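For orientation, the objective that both benchmarked modules wrap can be sketched in plain Python. In the published KTO formulation, a desirable example is penalized when the policy/reference log-ratio falls below a KL reference point, and an undesirable example when it rises above it. The helper below is an illustrative scalar sketch of that formulation only; the Liger and HF implementations operate on chunked batches, and the exact sign conventions and KL handling in this PR may differ (see the "Change sign of loss" and "Add KL" commits):

```python
import math

# Illustrative scalar sketch of the KTO loss, NOT the fused Liger kernel.
# `kl_ref` plays the role of the detached KL reference term.

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def kto_loss_per_example(policy_logp: float,
                         ref_logp: float,
                         is_desirable: bool,
                         kl_ref: float = 0.0,
                         beta: float = 0.1) -> float:
    """KTO loss for one example, given sequence log-probs under the
    policy and reference models and a desirable/undesirable label."""
    log_ratio = policy_logp - ref_logp
    if is_desirable:
        # Penalize desirable examples whose reward falls below the KL reference
        return 1.0 - sigmoid(beta * (log_ratio - kl_ref))
    # Penalize undesirable examples whose reward rises above the KL reference
    return 1.0 - sigmoid(beta * (kl_ref - log_ratio))

# A desirable example whose policy log-prob beats the reference gets a
# loss below 0.5; the mirror-image undesirable example gets one above 0.5.
good = kto_loss_per_example(-1.0, -2.0, is_desirable=True)
bad = kto_loss_per_example(-1.0, -2.0, is_desirable=False)
```

This also makes the role of `beta` visible: it scales how sharply the loss saturates as the policy drifts from the reference model.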
`liger_kernel/chunked_loss/__init__.py`:

```diff
@@ -1,4 +1,5 @@
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss  # noqa: F401
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss  # noqa: F401
```