Files
TRELLIS.2/data_toolkit/encode_ss_latent.py
2026-01-10 09:47:30 +00:00

164 lines
7.7 KiB
Python

import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import json
import argparse
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import trellis2.models as models
# This script only runs encoder forward passes (no backward/optimizer step is
# visible), so disable autograd globally to avoid recording computation graphs.
torch.set_grad_enabled(False)
def is_valid_sparse_tensor(tensor):
    """Check that a sparse tensor's features and coordinates are all finite.

    Args:
        tensor: object exposing ``.feats`` and ``.coords`` torch tensors
            (presumably a project sparse-tensor type — confirm at call sites).

    Returns:
        A 0-d boolean torch tensor (not a Python ``bool``): the ``feats`` check
        if it is falsy, otherwise the ``coords`` check. ``coords`` is only
        inspected when ``feats`` is entirely finite.
    """
    feats_all_finite = torch.isfinite(tensor.feats).all()
    return feats_all_finite and torch.isfinite(tensor.coords).all()
def clear_cuda_error():
    """Best-effort CUDA cleanup after a failed or rejected sample.

    Synchronizes the current device so any pending asynchronous kernel error
    surfaces here, then releases cached allocator blocks with ``empty_cache``.

    This is meant to be called from inside an exception handler, so it never
    raises. ``synchronize`` re-raises a pending CUDA error as ``RuntimeError``;
    that error is logged instead of escaping the handler and aborting the run.
    It is a no-op when CUDA is unavailable.

    Note: some CUDA errors (e.g. illegal memory access) are sticky. They corrupt
    the context and cannot be cleared this way, so later CUDA calls may keep
    failing.
    """
    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.synchronize()
    except RuntimeError as e:
        print(f"[CUDA Error] synchronize failed: {e}")
    try:
        torch.cuda.empty_cache()
    except RuntimeError as e:
        print(f"[CUDA Error] empty_cache failed: {e}")
if __name__ == '__main__':
    # Encode ss latents for each object:
    #   1. read the shape-latent voxel coords,
    #   2. rasterize them into a dense occupancy grid at --resolution,
    #   3. run the encoder,
    #   4. save z to <ss_latent_root>/ss_latents/<latent_name>/<sha256>.npz.
    # Newly encoded (and already-present) sha256s are written to
    # new_records/part_<rank>.csv in that directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--shape_latent_root', type=str, default=None,
                        help='Directory containing the shape latent files (defaults to --root)')
    parser.add_argument('--ss_latent_root', type=str, default=None,
                        help='Directory to save the ss latent files (defaults to --root)')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--resolution', type=int, default=64,
                        help='Sparse voxel resolution')
    # Required: it selects the input directory shape_latents/<name>/ for both the
    # metadata merge and the per-object coords files; None would crash os.path.join.
    parser.add_argument('--shape_latent_name', type=str, required=True,
                        help='Name of the shape latent files')
    parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS-image-large/ckpts/ss_enc_conv3d_16l8_fp16',
                        help='Pretrained encoder model')
    parser.add_argument('--model_root', type=str,
                        help='Root directory of models')
    parser.add_argument('--enc_model', type=str,
                        help='Encoder model. if specified, use this model instead of pretrained model')
    parser.add_argument('--ckpt', type=str,
                        help='Checkpoint to load')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args()
    # --enc_model loads <model_root>/<enc_model>/ckpts/encoder_<ckpt>.pt, so both
    # companions are needed; report a usage error instead of a late TypeError.
    if opt.enc_model is not None and (opt.model_root is None or opt.ckpt is None):
        parser.error('--enc_model requires --model_root and --ckpt')
    if opt.world_size < 1 or not 0 <= opt.rank < opt.world_size:
        parser.error('--rank must satisfy 0 <= rank < world_size')
    opt = edict(vars(opt))
    opt.shape_latent_root = opt.shape_latent_root or opt.root
    opt.ss_latent_root = opt.ss_latent_root or opt.root

    if opt.enc_model is None:
        latent_name = f'{opt.enc_pretrained.split("/")[-1]}_{opt.resolution}'
        encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
    else:
        latent_name = f'{opt.enc_model.split("/")[-1]}_{opt.ckpt}_{opt.resolution}'
        with open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r') as f:
            cfg = edict(json.load(f))
        encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
        ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
        # map_location='cpu' avoids restoring tensors onto whichever GPU the
        # checkpoint was saved from; load_state_dict copies them onto the
        # encoder's device.
        encoder.load_state_dict(torch.load(ckpt_path, map_location='cpu'), strict=False)
        encoder.eval()
        print(f'Loaded model from {ckpt_path}')

    # Single source of truth for the input/output directories, so the
    # "already processed" scan and the saver always look at the same place.
    shape_latent_dir = os.path.join(opt.shape_latent_root, 'shape_latents', opt.shape_latent_name)
    ss_latent_dir = os.path.join(opt.ss_latent_root, 'ss_latents', latent_name)
    os.makedirs(os.path.join(ss_latent_dir, 'new_records'), exist_ok=True)

    # get file list
    metadata_path = os.path.join(opt.root, 'metadata.csv')
    if not os.path.exists(metadata_path):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(metadata_path).set_index('sha256')
    # Merge optional per-stage metadata. combine_first keeps existing values and
    # only fills missing entries/columns from each extra table, in this order.
    for extra_path in (
        os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv'),
        os.path.join(shape_latent_dir, 'metadata.csv'),
        os.path.join(ss_latent_dir, 'metadata.csv'),
    ):
        if os.path.exists(extra_path):
            metadata = metadata.combine_first(pd.read_csv(extra_path).set_index('sha256'))
    metadata = metadata.reset_index()

    if opt.instances is None:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        # `== True` (not truthiness) so NaN from the merge counts as "not encoded".
        metadata = metadata[metadata['shape_latent_encoded'] == True]  # noqa: E712
        if 'ss_latent_encoded' in metadata.columns:
            metadata = metadata[metadata['ss_latent_encoded'] != True]  # noqa: E712
    else:
        # --instances is either a file with one sha256 per line or a comma list.
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    # Contiguous positional shard for this rank.
    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata.iloc[start:end]
    records = []

    # filter out objects that are already processed; latents are saved as
    # <ss_latent_dir>/<sha256>.npz, so that is the directory to scan.
    sha256_list = [os.path.splitext(f)[0] for f in os.listdir(ss_latent_dir) if f.endswith('.npz')]
    records.extend({'sha256': sha256, 'ss_latent_encoded': True} for sha256 in sha256_list)
    print(f'Found {len(sha256_list)} processed objects')
    metadata = metadata[~metadata['sha256'].isin(sha256_list)]

    print(f'Processing {len(metadata)} objects...')
    sha256s = list(metadata['sha256'].values)
    # Bounded queue: loader threads block once 32 grids are waiting for the GPU.
    load_queue = Queue(maxsize=32)
    with ThreadPoolExecutor(max_workers=32) as loader_executor, \
            ThreadPoolExecutor(max_workers=32) as saver_executor:

        def loader(sha256):
            # Enqueue exactly one item per sha256 (ss=None on failure); the
            # consumer loop below reads exactly len(sha256s) items.
            try:
                with np.load(os.path.join(shape_latent_dir, f'{sha256}.npz')) as data:
                    coords = data['coords']
                # Explicit check instead of assert (stripped under -O). Negative
                # indices would otherwise silently wrap in the assignment below.
                if not np.all((coords >= 0) & (coords < opt.resolution)):
                    raise ValueError(f"{sha256}: Invalid coords")
                coords = torch.from_numpy(coords).long()
                ss = torch.zeros(1, opt.resolution, opt.resolution, opt.resolution, dtype=torch.long)
                ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
            except Exception as e:
                print(f"[Loader Error] {sha256}: {e}")
                ss = None
            load_queue.put((sha256, ss))

        loader_executor.map(loader, sha256s)

        def saver(sha256, pack):
            save_path = os.path.join(ss_latent_dir, f'{sha256}.npz')
            try:
                np.savez_compressed(save_path, **pack)
            except Exception as e:
                # The future from submit() is never awaited, so log here or the
                # failure would be silently dropped.
                print(f"[Saver Error] {sha256}: {e}")
                return
            # list.append is atomic under CPython's GIL, so sharing `records`
            # across saver threads is safe.
            records.append({'sha256': sha256, 'ss_latent_encoded': True})

        for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
            # Fetch outside the try so `sha256` is always bound in the handler.
            sha256, ss = load_queue.get()
            if ss is None:
                print(f"[Skip] {sha256}: Failed to load input")
                continue
            try:
                ss = ss.cuda()[None].float()
                z = encoder(ss, sample_posterior=False)
                torch.cuda.synchronize()
                if not torch.isfinite(z).all():
                    print(f"[Skip] {sha256}: Non-finite latent")
                    clear_cuda_error()
                    continue
                pack = {
                    'z': z[0].cpu().numpy(),
                }
                saver_executor.submit(saver, sha256, pack)
            except Exception as e:
                print(f"[Error] {sha256}: {e}")
                clear_cuda_error()

        saver_executor.shutdown(wait=True)

    # Explicit columns so a rank with nothing to record still writes a CSV
    # with a header that later merges can parse.
    records = pd.DataFrame.from_records(records, columns=['sha256', 'ss_latent_encoded'])
    records.to_csv(os.path.join(ss_latent_dir, 'new_records', f'part_{opt.rank}.csv'), index=False)