Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make all Slerp operations use PyTorch rather than Numpy for simplicity #256

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 22 additions & 37 deletions mergekit/merge_methods/slerp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch

from mergekit.architecture import WeightInfo
Expand Down Expand Up @@ -85,15 +84,15 @@ def make_task(


def lerp(
    t: Union[float, torch.Tensor], v0: torch.Tensor, v1: torch.Tensor
) -> torch.Tensor:
    """
    Linear interpolation

    Args:
        t (float/torch.Tensor): Interpolation weight; 0.0 yields v0 and
                                1.0 yields v1
        v0 (torch.Tensor): Starting vector
        v1 (torch.Tensor): Final vector
    Returns:
        v2 (torch.Tensor): (1 - t) * v0 + t * v1, a newly allocated tensor
    """
    return (1 - t) * v0 + t * v1


def slerp(
    t: Union[float, torch.Tensor],
    v0: torch.Tensor,
    v1: torch.Tensor,
    DOT_THRESHOLD: float = 0.9995,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Spherical linear interpolation

    From: https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
    Args:
        t (float/torch.Tensor): Float value between 0.0 and 1.0
        v0 (torch.Tensor): Starting vector
        v1 (torch.Tensor): Final vector
        DOT_THRESHOLD (float): Threshold for considering the two vectors as
                               colinear. Not recommended to alter this.
        eps (float): Norm below which a vector is left un-normalized when
                     computing its direction
    Returns:
        v2 (torch.Tensor): Interpolation vector between v0 and v1. Inputs with
                           less precision than float32 (e.g. float16/bfloat16)
                           are promoted, so the result is at least float32.
    """
    # Do the math in at least float32: the dot product and arccos below are
    # too inaccurate in half precision. `.to` is a no-op when the dtype
    # already matches, so float32/float64 inputs are not copied.
    v0 = v0.to(torch.promote_types(v0.dtype, torch.float32))
    v1 = v1.to(torch.promote_types(v1.dtype, torch.float32))

    # Unit directions are only used to measure the angle. `normalize` returns
    # a new tensor, so v0/v1 keep their magnitudes for the final blend and no
    # explicit copies are needed.
    v0_dir = normalize(v0, eps)
    v1_dir = normalize(v1, eps)

    # Dot product with the normalized vectors (elementwise product + sum so
    # that it also works for weight matrices, not just 1-D vectors)
    dot = torch.sum(v0_dir * v1_dir)

    # If absolute value of dot product is almost 1, vectors are ~colinear, so
    # use lerp (sin(theta_0) would be ~0 and the division below unstable)
    if torch.abs(dot) > DOT_THRESHOLD:
        return lerp(t, v0, v1)

    # Calculate initial angle between v0 and v1
    theta_0 = torch.arccos(dot)
    sin_theta_0 = torch.sin(theta_0)

    # Angle at timestep t
    theta_t = theta_0 * t
    sin_theta_t = torch.sin(theta_t)

    # Finish the slerp algorithm
    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    return s0 * v0 + s1 * v1


def normalize(v: torch.Tensor, eps: float) -> torch.Tensor:
    """
    Scale a tensor to unit norm

    Args:
        v (torch.Tensor): Tensor to normalize; it is not modified in place
        eps (float): Norms at or below this value are treated as zero
    Returns:
        v (torch.Tensor): v divided by its norm (a new tensor), or v itself,
                          unchanged, when the norm does not exceed eps
    """
    norm_v = torch.linalg.norm(v)
    # Skip the division for (near-)zero tensors to avoid producing inf/NaN
    if norm_v > eps:
        v = v / norm_v
    return v
Loading