[docs]: fix the bug in cal_fvd.py #247

Open · wants to merge 1 commit into base: main
26 changes: 7 additions & 19 deletions docs/EVAL.md
@@ -28,13 +28,8 @@ You can easily calculate the following video quality metrics, which supports the
pip install ftfy
pip install regex
pip install tqdm
pip install cupy
```
## Pretrained model
- FVD
Before you calculate FVD, you should first download the FVD pre-trained model. You can manually download either of the following and put it into the FVD folder.
- `i3d_torchscript.pt` from [here](https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt)
- `i3d_pretrained_400.pt` from [here](https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI)


## Other Notices
1. Make sure the pixel values of the videos are in [0, 1]; a sketch of this normalization follows.
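
For instance, a minimal sketch of that normalization (`to_unit_range` is a hypothetical helper, assuming frames arrive as `uint8` tensors in [0, 255]):

```python
import torch

def to_unit_range(video: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: map uint8 frames in [0, 255] to floats in [0, 1]
    return video.float() / 255.0

frames = torch.randint(0, 256, (8, 16, 3, 64, 64), dtype=torch.uint8)
assert float(to_unit_range(frames).max()) <= 1.0
```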
@@ -82,28 +77,21 @@ You can easily calculate the following video quality metrics, which supports the
# change the file paths and set frame_num, resolution, etc. as needed

# clip_score cross modality
cd opensora/eval
bash script/cal_clip_score.sh


bash scripts/eval/cal_clip_score.sh

# fvd
cd opensora/eval
bash script/cal_fvd.sh
# fvd
bash scripts/eval/cal_fvd.sh

# psnr
cd opensora/eval
bash eval/script/cal_psnr.sh
bash scripts/eval/cal_psnr.sh


# ssim
cd opensora/eval
bash eval/script/cal_ssim.sh
bash scripts/eval/cal_ssim.sh


# lpips
cd opensora/eval
bash eval/script/cal_lpips.sh
bash scripts/eval/cal_lpips.sh
```

# Acknowledgement
Binary file not shown.
222 changes: 141 additions & 81 deletions opensora/eval/cal_fvd.py
@@ -1,85 +1,145 @@

import numpy as np
import scipy.linalg  # import the submodule explicitly; a bare `import scipy` does not guarantee scipy.linalg is loaded
from torch.utils.data import DataLoader, TensorDataset
import torch
from tqdm import tqdm

def trans(x):
    # if greyscale images add channel
    if x.shape[-3] == 1:
        x = x.repeat(1, 1, 3, 1, 1)

    # permute BTCHW -> BCTHW
    x = x.permute(0, 2, 1, 3, 4)

    return x

def calculate_fvd(videos1, videos2, device, method='styleganv'):

    if method == 'styleganv':
        from fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained
    elif method == 'videogpt':
        from fvd.videogpt.fvd import load_i3d_pretrained
        from fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats
        from fvd.videogpt.fvd import frechet_distance

import hashlib
import os
import glob
import requests
import re
import html
import io
import uuid

_feature_detector_cache = dict()

# helper that downloads a file from the internet, caches it, and opens it as if it were a regular file
def open_url(url, num_attempts=10, verbose=False, cache_dir=None):
    assert num_attempts >= 1

    if cache_dir is None:
        cache_dir = './loaded_models'
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
    if len(cache_files) == 1:
        f_name = cache_files[0]
        return open(f_name, 'rb')

    with requests.Session() as session:
        if verbose:
            print("Downloading ", url, flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise Exception("Interrupted")
            except:
                if not attempts_left:
                    if verbose:
                        print("failed!")
                    raise
                if verbose:
                    print('.')

    safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
    cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
    temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
    os.makedirs(cache_dir, exist_ok=True)
    with open(temp_file, 'wb') as f:
        f.write(url_data)
    os.replace(temp_file, cache_file)

    return io.BytesIO(url_data)

# load the feature extractor either from cache or the specified URL
def get_feature_detector(detector_url, device):
    key = (detector_url, device)
    if key not in _feature_detector_cache:
        with open_url(detector_url, verbose=True) as f:
            _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
    return _feature_detector_cache[key]


"""
This function is used to first extract feature representation vectors of the videos using a pretrained model
Then the mean and covariance of the representation vectors are calculated and returned
"""
def compute_feature_stats(data, detector_url, detector_kwargs, batch_size, max_items, device):
# if wanted reduce the number of elements used for calculating the FVD
num_items = len(data)
if max_items:
num_items = min(num_items, max_items)
data = data[:num_items]

# load the pretrained feature extraction modeö
detector = get_feature_detector(detector_url, device=device)

dataset = TensorDataset(data)
loader = DataLoader(dataset, batch_size=batch_size)
all_features = []
for batch in loader:
batch = batch[0]
# if more than 3 channels are available we split the channel dimension into chunks of 3 and concatenate to batch dimension
if batch.size(1) != 3:
pad_size = 3 - (batch.size(1) % 3)
pad = torch.zeros(batch.size(0), pad_size, batch.size(2), batch.size(3), batch.size(4), device=batch.device)
batch = torch.cat([batch, pad], dim=1)
batch = torch.cat(torch.chunk(batch, chunks=batch.size(1)//3, dim=1), dim=0)
batch = batch.to(device)
# extract feature vector using pretrained model
features = detector(batch, **detector_kwargs)
features = features.detach().cpu().numpy()
all_features.append(features)
# concatenate batches to one numpy array
stacked_features = np.concatenate(all_features, axis=0)

# calculate mean and covariance matrix across the extracted features
mu = np.mean(stacked_features, axis=0)
sigma = np.cov(stacked_features, rowvar=False)

return mu, sigma
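
For reference, the statistics returned here feed the closed-form Fréchet distance between the two fitted Gaussians,

```math
\mathrm{FVD} = \lVert \mu_1 - \mu_2 \rVert_2^2 + \operatorname{Tr}\!\left(\Sigma_1 + \Sigma_2 - 2\,(\Sigma_1 \Sigma_2)^{1/2}\right)
```

which is exactly what the `m + np.trace(...)` expression at the end of `calculate_fvd` below evaluates.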

def calculate_fvd(y_true: torch.Tensor, y_pred: torch.Tensor, device: torch.device):
    '''
    y_true: (bz, t, c, h, w) `num_videos x num_frames x channels x height x width`
    y_pred: (bz, t, c, h, w) `num_videos x num_frames x channels x height x width`
    '''
    # print(y_true.shape) # torch.Size([5, 20, 3, 64, 64])
    # print(y_pred.shape) # torch.Size([5, 20, 3, 64, 64])
    # permute BTCHW -> BCTHW, as expected by the I3D detector
    y_true = torch.permute(y_true, (0, 2, 1, 3, 4)).contiguous()
    y_pred = torch.permute(y_pred, (0, 2, 1, 3, 4)).contiguous()

    batch_size = y_true.shape[0]
    max_items = batch_size
    print("calculate_fvd...")
    detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1'
    detector_kwargs = dict(rescale=True, resize=True, return_features=True)  # return raw features before the softmax layer

    # videos [batch_size, timestamps, channel, h, w]

    assert videos1.shape == videos2.shape

    i3d = load_i3d_pretrained(device=device)
    fvd_results = []

    # support grayscale input: if grayscale -> repeat channel * 3
    # BTCHW -> BCTHW
    # videos -> [batch_size, channel, timestamps, h, w]

    videos1 = trans(videos1)
    videos2 = trans(videos2)

    fvd_results = {}

    # to calculate FVD, each clip_timestamp must be >= 10
    for clip_timestamp in tqdm(range(10, videos1.shape[-3] + 1)):

        # get a video clip
        # videos_clip [batch_size, channel, timestamps[:clip], h, w]
        videos_clip1 = videos1[:, :, :clip_timestamp]
        videos_clip2 = videos2[:, :, :clip_timestamp]

        # get FVD features
        feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device)
        feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device)

        # calculate FVD for timestamps[:clip]
        fvd_results[clip_timestamp] = frechet_distance(feats1, feats2)

    result = {
        "value": fvd_results,
        "video_setting": videos1.shape,
        "video_setting_name": "batch_size, channel, time, height, width",
    }

    return result

# test code / usage example

def main():
    NUMBER_OF_VIDEOS = 8
    VIDEO_LENGTH = 50
    CHANNEL = 3
    SIZE = 64
    videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    device = torch.device("cuda")
    # device = torch.device("cpu")

    import json
    result = calculate_fvd(videos1, videos2, device, method='videogpt')
    print(json.dumps(result, indent=4))

    result = calculate_fvd(videos1, videos2, device, method='styleganv')
    print(json.dumps(result, indent=4))

if __name__ == "__main__":
    main()
    # calculate the mean and covariance matrix of the representation vectors for ground-truth and predicted videos
    mu_true, sigma_true = compute_feature_stats(y_true, detector_url, detector_kwargs, batch_size, max_items, device)
    mu_pred, sigma_pred = compute_feature_stats(y_pred, detector_url, detector_kwargs, batch_size, max_items, device)
    m = np.square(mu_pred - mu_true).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma_pred, sigma_true), disp=False)
    fvd = np.real(m + np.trace(sigma_pred + sigma_true - s * 2))

    return float(fvd)
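
A minimal usage sketch of the rewritten `calculate_fvd` (the shapes, frame count, and device fallback here are illustrative assumptions, not part of the PR):

```python
import torch

# hypothetical smoke test: 8 videos of 16 frames, 3 channels, 64x64,
# with pixel values already in [0, 1] as EVAL.md requires
videos_real = torch.rand(8, 16, 3, 64, 64)  # (B, T, C, H, W)
videos_fake = torch.rand(8, 16, 3, 64, 64)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("FVD:", calculate_fvd(videos_real, videos_fake, device))
```

The first call downloads and caches `i3d_torchscript.pt` into `./loaded_models`, so it needs network access or a pre-populated cache.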
5 changes: 1 addition & 4 deletions opensora/eval/eval_common_metric.py
@@ -133,19 +133,17 @@ def _preprocess(video_data, short_size=128, crop_size=None):
        ]
    )
    video_outputs = transform(video_data)
    # video_outputs = torch.unsqueeze(video_outputs, 0) # (bz,c,t,h,w)
    return video_outputs


def calculate_common_metric(args, dataloader, device):

    score_list = []
    for batch_data in tqdm(dataloader):  # {'real': real_video_tensor, 'generated': generated_video_tensor}
        real_videos = batch_data['real']
        generated_videos = batch_data['generated']
        assert real_videos.shape[2] == generated_videos.shape[2]
        if args.metric == 'fvd':
            tmp_list = list(calculate_fvd(real_videos, generated_videos, args.device, method=args.fvd_method)['value'].values())
            tmp_list = [calculate_fvd(real_videos, generated_videos, args.device)]
        elif args.metric == 'ssim':
            tmp_list = list(calculate_ssim(real_videos, generated_videos)['value'].values())
        elif args.metric == 'psnr':
@@ -178,7 +176,6 @@ def main():
    parser.add_argument('--sample_rate', type=int, default=1)
    parser.add_argument('--subset_size', type=int, default=None)
    parser.add_argument("--metric", type=str, default="fvd", choices=['fvd', 'psnr', 'ssim', 'lpips', 'flolpips'])
    parser.add_argument("--fvd_method", type=str, default='styleganv', choices=['styleganv', 'videogpt'])


    args = parser.parse_args()
90 changes: 0 additions & 90 deletions opensora/eval/fvd/styleganv/fvd.py

This file was deleted.
