
Keypoint-Guided Optimal Transport (NeurIPS 2022)


XJTU-XGU/KPG-RL



Official code for the conference paper "Keypoint-Guided Optimal Transport with Applications in Heterogeneous Domain Adaptation" (NeurIPS 2022) and its journal version "Keypoint-Guided Optimal Transport".

We present a new optimal transport (OT) model, KPG-RL, that leverages a few keypoints annotated in advance to guide the matching of the remaining points. To realize this guidance, we model the transport plan with a mask and preserve the relation of each data point to the keypoints.
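As we read the description above, KPG-RL can be sketched as a linear program over a masked transport plan. The symbols below (mask $M$, guiding matrix $G$) are our own notation for this sketch; the paper gives the exact formulation:

```latex
\min_{\pi \in \mathbb{R}_{\ge 0}^{m \times n}} \ \langle \pi, G \rangle
\quad \text{s.t.} \quad
\pi \mathbf{1}_n = p, \quad
\pi^{\top} \mathbf{1}_m = q, \quad
\pi_{ij} = 0 \ \text{whenever} \ M_{ij} = 0
```

Here $M \in \{0,1\}^{m \times n}$ lets each source keypoint send mass only to its paired target keypoint, and $G_{ij}$ measures how differently $x^s_i$ and $x^t_j$ relate to their respective keypoints.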

[Figure: mask-based transport plan]
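A minimal sketch of such a mask, assuming (our reading, not the repository's code) that a source keypoint may only transport mass to its paired target keypoint and vice versa, while all other entries stay unconstrained. `build_keypoint_mask` is a hypothetical helper:

```python
import numpy as np


def build_keypoint_mask(m, n, K):
    """Hypothetical helper: 0/1 mask for an (m, n) transport plan.

    A source keypoint i paired with target keypoint j may only send mass to j,
    and j may only receive mass from i; every other entry is left free (1).
    """
    M = np.ones((m, n))
    for i, j in K:
        M[i, :] = 0  # source keypoint i: block every target ...
        M[:, j] = 0  # target keypoint j: block every source ...
        M[i, j] = 1  # ... except the annotated partner pair (i, j)
    return M


print(build_keypoint_mask(3, 4, [(0, 1)]))
```

Multiplying a transport plan (or a Gibbs kernel) elementwise by this mask zeroes every forbidden pair.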

[Figure: relation modeling]
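The relation modeling can be sketched as follows, assuming that each point's relation to the keypoints is a softmax over negative Euclidean distances with temperature `tau` (cf. the `tau_s`/`tau_t` arguments of `kpg_rl`) and that source and target relations are compared with the Jensen-Shannon divergence. The helper names are hypothetical, and the repository's implementation (e.g., distance normalization) may differ:

```python
import numpy as np


def relation_vectors(x, keypoints, tau):
    """Row a is a softmax over -dist(x[a], keypoint_k) / tau across keypoints k."""
    d = np.linalg.norm(x[:, None, :] - keypoints[None, :, :], axis=-1)
    logits = -d / tau
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    R = np.exp(logits)
    return R / R.sum(axis=1, keepdims=True)


def js_divergence_matrix(Rs, Rt, eps=1e-10):
    """G[i, j] = Jensen-Shannon divergence between the rows Rs[i] and Rt[j]."""
    P = Rs[:, None, :] + eps
    Q = Rt[None, :, :] + eps
    mid = 0.5 * (P + Q)
    kl_pm = np.sum(P * np.log(P / mid), axis=-1)
    kl_qm = np.sum(Q * np.log(Q / mid), axis=-1)
    return 0.5 * (kl_pm + kl_qm)


def guiding_matrix(xs, xt, K, tau_s=0.1, tau_t=0.1):
    """Hypothetical guiding cost: compare each point's relation to its own domain's keypoints."""
    src_idx = [i for i, _ in K]
    tgt_idx = [j for _, j in K]
    Rs = relation_vectors(xs, xs[src_idx], tau_s)
    Rt = relation_vectors(xt, xt[tgt_idx], tau_t)
    return js_divergence_matrix(Rs, Rt)
```

Because the relation vectors live in the same K-dimensional space for both domains, this cost can be computed even when `xs` and `xt` have different feature dimensions.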

With the keypoints, our approach clearly improves the matching accuracy.

Requirements

python3.6
scipy==1.7.1
numpy==1.20.3
matplotlib==3.4.3
cvxpy==1.2.0
pot==0.8.1.0
pytorch==1.5.0

Instructions for the folders

HDA: code for the HDA experiments (Section 5.2)
Open-Set HDA: code for the Open-Set HDA experiments (Section 5.3)
UDA: code for the unsupervised DA experiments (Appendix B.3)
toy_data: code for the toy experiments for KPG-RL (Section 5.1)
toy_data_partial: code for the toy experiments for Partial KPG-RL (Appendix B.1)

Please follow the Readme.md in each folder to run the code.

News

[2023.04] We provide an easy-to-follow demo in demo.ipynb that covers the model, the algorithm, and the code. Feel free to try it!

[2023.06] We uploaded the keypoint-guided OT package to the folder "keypoint_guided_optimal_transport" and usage examples to "Examples_of_KPG_OT.py".

[2023.08] A linear programming implementation based on sparse matrices is available in "keypoint_guided_optimal_transport/linearprog.py"; it is memory-efficient for large numbers of data points. To use it, set "sparse=True" in the function "lp" as follows:

from keypoint_guided_optimal_transport.linearprog import lp
lp(p, q, C, Mask=None, sparse=True)  # p: (m,), q: (n,), C and Mask: (m,n)
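The memory saving comes from never storing a dense (m·n) × (m+n) constraint matrix. Below is a sketch of the same idea using SciPy rather than the repository's `lp` (whose internals are not reproduced here): the marginal constraints are built as `scipy.sparse` matrices, masked-out entries are dropped from the LP variables, and the problem is solved with `scipy.optimize.linprog(method="highs")`. `sparse_ot_lp` is a hypothetical name:

```python
import numpy as np
import scipy.sparse as sp
from scipy.optimize import linprog


def sparse_ot_lp(p, q, C, mask=None):
    """Exact OT as a sparse LP; entries with mask == 0 are excluded from the variables."""
    m, n = C.shape
    # Row-sum constraints: sum_j pi[i, j] = p[i]  (pi is flattened row-major)
    A_rows = sp.kron(sp.identity(m), np.ones((1, n)))
    # Column-sum constraints: sum_i pi[i, j] = q[j]
    A_cols = sp.kron(np.ones((1, m)), sp.identity(n))
    A_eq = sp.vstack([A_rows, A_cols]).tocsc()
    b_eq = np.concatenate([p, q])

    # Keep only the allowed entries of the plan as LP variables
    keep = np.ones(m * n, dtype=bool) if mask is None else mask.ravel() != 0
    idx = np.flatnonzero(keep)

    res = linprog(C.ravel()[idx], A_eq=A_eq[:, idx], b_eq=b_eq,
                  bounds=(0, None), method="highs")
    if not res.success:
        raise RuntimeError(res.message)

    pi = np.zeros(m * n)
    pi[idx] = res.x
    return pi.reshape(m, n)
```

Dropping masked variables also shrinks the LP, which is one reason a keypoint mask pairs naturally with a sparse solver.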

Using keypoint-guided OT in your code

For the KPG-RL model, use the following code:

# import keypoint-guided OT
import numpy as np
from keypoint_guided_optimal_transport.keypoint_guided_OT import KeyPointGuidedOT

# define the samples xs and xt, the masses p and q, and the keypoint index pairs K
# (toy random data for illustration; replace with your own)
m, n, d = 100, 100, 2
xs = np.random.randn(m, d)        # ndarray with shape (m,d)
xt = np.random.randn(n, d) + 2.0  # ndarray with shape (n,d)
p = np.ones(m) / m                # ndarray with shape (m,)
q = np.ones(n) / n                # ndarray with shape (n,)
K = [(0, 0), (10, 20)]            # list of (source index, target index) tuples

kgot = KeyPointGuidedOT()
pi = kgot.kpg_rl(p, q, xs, xt, K, cost_function="L2", algorithm="linear_programming", tau_s=0.1, tau_t=0.1,
                 normalized=True, reg=0.0001, max_iterations=100000, thres=1e-5, eps=1e-10)
# pi is an ndarray with shape (m,n). The algorithm can be "linear_programming" or "sinkhorn".

For the KPG-RL-KP model, use the following code:

# import keypoint-guided OT
import numpy as np
from keypoint_guided_optimal_transport.keypoint_guided_OT import KeyPointGuidedOT

# define the samples xs and xt, the masses p and q, and the keypoint index pairs K
# (toy random data for illustration; replace with your own)
m, n, d = 100, 100, 2
xs = np.random.randn(m, d)        # ndarray with shape (m,d)
xt = np.random.randn(n, d) + 2.0  # ndarray with shape (n,d)
p = np.ones(m) / m                # ndarray with shape (m,)
q = np.ones(n) / n                # ndarray with shape (n,)
K = [(0, 0), (10, 20)]            # list of (source index, target index) tuples
alpha = 0.5                       # combination coefficient (illustrative value)

kgot = KeyPointGuidedOT()
pi = kgot.kpg_rl_kp(p, q, xs, xt, K, alpha=alpha, cost_function="L2", algorithm="linear_programming", tau_s=0.1, tau_t=0.1,
                    normalized=True, reg=0.0001, max_iterations=100000, eps=1e-10, thres=1e-5)
# pi is an ndarray with shape (m,n). The algorithm can be "linear_programming" or "sinkhorn".

For the KPG-RL-GW model, use the following code:

# import keypoint-guided OT
import numpy as np
from keypoint_guided_optimal_transport.keypoint_guided_OT import KeyPointGuidedOT

# define the samples xs and xt, the masses p and q, and the keypoint index pairs K
# (toy random data for illustration; replace with your own)
m, n, d = 100, 100, 2
xs = np.random.randn(m, d)        # ndarray with shape (m,d)
xt = np.random.randn(n, d) + 2.0  # ndarray with shape (n,d)
p = np.ones(m) / m                # ndarray with shape (m,)
q = np.ones(n) / n                # ndarray with shape (n,)
K = [(0, 0), (10, 20)]            # list of (source index, target index) tuples
alpha = 0.5                       # combination coefficient (illustrative value)

kgot = KeyPointGuidedOT()
pi = kgot.kpg_rl_gw(p, q, xs, xt, K, alpha=alpha, cost_function="L2", algorithm="linear_programming", tau_s=0.1, tau_t=0.1,
                    normalized=True, reg=0.0001, max_iterations=100000, eps=1e-10, thres=1e-5)
# pi is an ndarray with shape (m,n). The algorithm can be "linear_programming" or "sinkhorn".
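The GW part of KPG-RL-GW compares the intra-domain structure of the two point sets. Below is a sketch of the Gromov-Wasserstein objective with squared loss, sum over i,j,k,l of (Cs[i,k] - Ct[j,l])^2 · pi[i,j] · pi[k,l], evaluated without forming the 4-D tensor. `Cs` and `Ct` are assumed intra-domain cost matrices and `gw_square_loss` is a hypothetical name; how the repository weights this term against the guiding cost through `alpha` is not shown here:

```python
import numpy as np


def gw_square_loss(Cs, Ct, pi):
    """sum_{i,j,k,l} (Cs[i,k] - Ct[j,l])**2 * pi[i,j] * pi[k,l] in O(m^2 n + m n^2)."""
    a = pi.sum(axis=1)  # row marginal of the plan
    b = pi.sum(axis=0)  # column marginal of the plan
    # Expand the square: Cs^2 term + Ct^2 term - 2 * cross term
    term_s = a @ (Cs ** 2) @ a
    term_t = b @ (Ct ** 2) @ b
    cross = np.sum((Cs @ pi @ Ct.T) * pi)
    return term_s + term_t - 2.0 * cross
```

Expanding the square is what keeps the cost at matrix-product complexity instead of O(m^2 n^2) memory for the full loss tensor.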

For the partial KPG-RL model, use the following code:

# import keypoint-guided OT
import numpy as np
from keypoint_guided_optimal_transport.keypoint_guided_OT import KeyPointGuidedOT

# define the samples xs and xt, the masses p and q, the keypoint index pairs K, and the total mass s
# (toy random data for illustration; replace with your own)
m, n, d = 100, 100, 2
xs = np.random.randn(m, d)        # ndarray with shape (m,d)
xt = np.random.randn(n, d) + 2.0  # ndarray with shape (n,d)
p = np.ones(m) / m                # ndarray with shape (m,)
q = np.ones(n) / n                # ndarray with shape (n,)
K = [(0, 0), (10, 20)]            # list of (source index, target index) tuples
s = 0.5                           # total mass to transport

kgot = KeyPointGuidedOT()
pi = kgot.partial_kpg_rl(p, q, xs, xt, K, s=s, cost_function="L2", tau_s=1.0,
                         tau_t=1.0, normalized=False, eps=1e-10)
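Partial OT transports only a total mass `s` instead of all the mass. A sketch of plain partial OT as a linear program, using `scipy.optimize.linprog` rather than the repository's `partial_kpg_rl`: row and column sums are bounded by `p` and `q`, and the total transported mass is fixed to `s`. In partial KPG-RL we would assume the cost is the keypoint-guided cost and the keypoint mask also applies; both are omitted here, and `partial_ot_lp` is a hypothetical name:

```python
import numpy as np
from scipy.optimize import linprog


def partial_ot_lp(p, q, C, s):
    """Move exactly s units of mass: pi 1 <= p, pi^T 1 <= q, sum(pi) = s, pi >= 0."""
    m, n = C.shape
    if s > min(p.sum(), q.sum()) + 1e-12:
        raise ValueError("s cannot exceed the smaller total mass")
    # Inequality rows: row sums <= p, column sums <= q (pi flattened row-major)
    A_ub = np.vstack([np.kron(np.eye(m), np.ones((1, n))),
                      np.kron(np.ones((1, m)), np.eye(n))])
    b_ub = np.concatenate([p, q])
    # Equality row: total transported mass equals s
    A_eq = np.ones((1, m * n))
    res = linprog(C.ravel(), A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=[s],
                  bounds=(0, None), method="highs")
    if not res.success:
        raise RuntimeError(res.message)
    return res.x.reshape(m, n)
```

With `s` equal to the full mass, the problem reduces to standard balanced OT; smaller `s` lets the solver leave costly points (e.g., outliers) unmatched.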

Citation:

@inproceedings{gu2022keypointguided,
      title={Keypoint-Guided Optimal Transport with Applications in Heterogeneous Domain Adaptation},
      author={Xiang Gu and Yucheng Yang and Wei Zeng and Jian Sun and Zongben Xu},
      booktitle={Advances in Neural Information Processing Systems},
      editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
      year={2022},
      url={https://openreview.net/forum?id=m6DJxSuKuqF}
}

@misc{gu2023keypointguided,
      title={Keypoint-Guided Optimal Transport}, 
      author={Xiang Gu and Yucheng Yang and Wei Zeng and Jian Sun and Zongben Xu},
      year={2023},
      eprint={2303.13102},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Contact

If you have any questions, please do not hesitate to contact [email protected].

The code is based on the POT package.
