Release Training Code

JeffreyXiang
2026-01-10 09:47:30 +00:00
parent 903bfcf51a
commit 5565d240c4
36 changed files with 4853 additions and 24 deletions

README.md

@@ -50,7 +50,7 @@ Data processing is streamlined for instant conversions that are fully **renderin
- [x] Release pretrained checkpoints (4B)
- [x] Hugging Face Spaces demo
- [x] Release shape-conditioned texture generation inference code
- - [ ] Release training code (Current schedule: before 12/31/2025)
+ - [x] Release training code
## 🛠️ Installation
@@ -186,6 +186,117 @@ Then, you can access the demo at the address shown in the terminal.
Please refer to the [example_texturing.py](example_texturing.py) for an example of how to generate PBR textures for a given 3D shape. Also, you can use the [app_texturing.py](app_texturing.py) to run a web demo for PBR texture generation.
## 🏋️ Training
We provide the full training codebase, enabling users to train **TRELLIS.2** from scratch or fine-tune it on custom datasets.
### 1. Data Preparation
Before training, raw 3D assets must be converted into the **O-Voxel** representation. This process includes mesh conversion, compact structured latent generation, and metadata preparation.
> 📂 **Please refer to [data_toolkit/README.md](data_toolkit/README.md) for detailed instructions on data preprocessing and dataset organization.**
### 2. Running Training
Training is managed through the `train.py` script, which accepts multiple command-line arguments to configure experiments:
* `--config`: Path to the experiment configuration file.
* `--output_dir`: Directory for training outputs.
* `--load_dir`: Directory to load checkpoints from (defaults to `output_dir`).
* `--ckpt`: Checkpoint step to resume from (defaults to the latest).
* `--data_dir`: Dataset path or a JSON string specifying dataset locations.
* `--auto_retry`: Number of automatic retries upon failure.
* `--tryrun`: Perform a dry run without actual training.
* `--profile`: Enable training profiling.
* `--num_nodes`: Number of nodes for distributed training.
* `--node_rank`: Rank of the current node.
* `--num_gpus`: Number of GPUs per node (defaults to all available GPUs).
* `--master_addr`: Master node address for distributed training.
* `--master_port`: Port for distributed training communication.
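For multi-node runs, combine the distributed flags above with a shared master address and port. Below is a minimal sketch of a two-node launch; the address, port, and GPU count are placeholders for your cluster, not repository defaults:

```sh
# Run on node 0; repeat on node 1 with --node_rank 1.
# Dataset arguments are omitted here; see the examples below.
python train.py \
    --config configs/scvae/shape_vae_next_dc_f16c32_fp16.json \
    --output_dir results/shape_vae_next_dc_f16c32_fp16 \
    --num_nodes 2 \
    --node_rank 0 \
    --num_gpus 8 \
    --master_addr 10.0.0.1 \
    --master_port 29500
```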
### SC-VAE Training
To train the shape SC-VAE, run:
```sh
python train.py \
--config configs/scvae/shape_vae_next_dc_f16c32_fp16.json \
--output_dir results/shape_vae_next_dc_f16c32_fp16 \
--data_dir "{\"ObjaverseXL_sketchfab\": {\"base\": \"datasets/ObjaverseXL_sketchfab\", \"mesh_dump\": \"datasets/ObjaverseXL_sketchfab/mesh_dumps\", \"dual_grid\": \"datasets/ObjaverseXL_sketchfab/dual_grid_256\", \"asset_stats\": \"datasets/ObjaverseXL_sketchfab/asset_stats\"}}"
```
This command trains the shape SC-VAE on the **Objaverse-XL** dataset using the `shape_vae_next_dc_f16c32_fp16.json` configuration. Training outputs will be saved to `results/shape_vae_next_dc_f16c32_fp16`.
The dataset is specified as a JSON string, where each dataset entry includes:
* `base`: Root directory of the dataset.
* `mesh_dump`: Directory containing preprocessed mesh dumps.
* `dual_grid`: Directory with precomputed dual-grid representations.
* `asset_stats`: Directory containing precomputed asset statistics.
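Since escaping the JSON string inline is error-prone, it can help to assemble it in a single-quoted shell variable first. A sketch equivalent to the command above:

```sh
# Same dataset specification as above, without backslash escaping
DATA_DIR='{"ObjaverseXL_sketchfab": {"base": "datasets/ObjaverseXL_sketchfab", "mesh_dump": "datasets/ObjaverseXL_sketchfab/mesh_dumps", "dual_grid": "datasets/ObjaverseXL_sketchfab/dual_grid_256", "asset_stats": "datasets/ObjaverseXL_sketchfab/asset_stats"}}'
python train.py \
    --config configs/scvae/shape_vae_next_dc_f16c32_fp16.json \
    --output_dir results/shape_vae_next_dc_f16c32_fp16 \
    --data_dir "$DATA_DIR"
```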
To fine-tune the model at a higher resolution, use the `shape_vae_next_dc_f16c32_fp16_ft_512.json` configuration. Remember to update the `finetune_ckpt` field and adjust the dataset paths accordingly.
To train the texture SC-VAE, run:
```sh
python train.py \
--config configs/scvae/tex_vae_next_dc_f16c32_fp16.json \
--output_dir results/tex_vae_next_dc_f16c32_fp16 \
--data_dir "{\"ObjaverseXL_sketchfab\": {\"base\": \"datasets/ObjaverseXL_sketchfab\", \"pbr_dump\": \"datasets/ObjaverseXL_sketchfab/pbr_dumps\", \"pbr_voxel\": \"datasets/ObjaverseXL_sketchfab/pbr_voxels_256\", \"asset_stats\": \"datasets/ObjaverseXL_sketchfab/asset_stats\"}}"
```
### Flow Model Training
To train the sparse structure flow model, run:
```sh
python train.py \
--config configs/gen/ss_flow_img_dit_1_3B_64_bf16.json \
--output_dir results/ss_flow_img_dit_1_3B_64_bf16 \
--data_dir "{\"ObjaverseXL_sketchfab\": {\"base\": \"datasets/ObjaverseXL_sketchfab\", \"ss_latent\": \"datasets/ObjaverseXL_sketchfab/ss_latents/ss_enc_conv3d_16l8_fp16_64\", \"render_cond\": \"datasets/ObjaverseXL_sketchfab/renders_cond\"}}"
```
This command trains the sparse-structure flow model on the **Objaverse-XL** dataset using the specified configuration file. Outputs are saved to `results/ss_flow_img_dit_1_3B_64_bf16`.
The dataset configuration includes:
* `base`: Root dataset directory.
* `ss_latent`: Directory containing precomputed sparse-structure latents.
* `render_cond`: Directory containing conditional rendering images.
The second- and third-stage flow models for shape and texture generation can be trained using the following configurations:
* Shape flow: `slat_flow_img2shape_dit_1_3B_512_bf16.json`
* Texture flow: `slat_flow_imgshape2tex_dit_1_3B_512_bf16.json`
Example commands:
```sh
# Shape flow model
python train.py \
--config configs/gen/slat_flow_img2shape_dit_1_3B_512_bf16.json \
--output_dir results/slat_flow_img2shape_dit_1_3B_512_bf16 \
--data_dir "{\"ObjaverseXL_sketchfab\": {\"base\": \"datasets/ObjaverseXL_sketchfab\", \"shape_latent\": \"datasets/ObjaverseXL_sketchfab/shape_latents/shape_enc_next_dc_f16c32_fp16_512\", \"render_cond\": \"datasets/ObjaverseXL_sketchfab/renders_cond\"}}"
# Texture flow model
python train.py \
--config configs/gen/slat_flow_imgshape2tex_dit_1_3B_512_bf16.json \
--output_dir results/slat_flow_imgshape2tex_dit_1_3B_512_bf16 \
--data_dir "{\"ObjaverseXL_sketchfab\": {\"base\": \"datasets/ObjaverseXL_sketchfab\", \"shape_latent\": \"datasets/ObjaverseXL_sketchfab/shape_latents/shape_enc_next_dc_f16c32_fp16_512\", \"pbr_latent\": \"datasets/ObjaverseXL_sketchfab/pbr_latents/tex_enc_next_dc_f16c32_fp16_512\", \"render_cond\": \"datasets/ObjaverseXL_sketchfab/renders_cond\"}}"
```
Higher-resolution fine-tuning can be performed by updating the `finetune_ckpt` field in the following configuration files and adjusting the dataset paths accordingly:
* `slat_flow_img2shape_dit_1_3B_512_bf16_ft1024.json`
* `slat_flow_imgshape2tex_dit_1_3B_512_bf16_ft1024.json`
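Both fine-tuning configs ship with a `finetune_ckpt` placeholder; the field maps each model name to the checkpoint it initializes from. Its shape, as shipped (replace the placeholder with the path to your 512-resolution checkpoint):

```json
"finetune_ckpt": {
    "denoiser": "PATH_TO_512_CKPT"
}
```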
## 🧩 Related Packages
TRELLIS.2 is built upon several specialized high-performance packages developed by our team:


@@ -0,0 +1,98 @@
{
"models": {
"denoiser": {
"name": "ElasticSLatFlowModel",
"args": {
"resolution": 32,
"in_channels": 32,
"out_channels": 32,
"model_channels": 1536,
"cond_channels": 1024,
"num_blocks": 30,
"num_heads": 12,
"mlp_ratio": 5.3334,
"pe_mode": "rope",
"share_mod": true,
"initialization": "scaled",
"qk_rms_norm": true,
"qk_rms_norm_cross": true
}
}
},
"dataset": {
"name": "ImageConditionedSLatShape",
"args": {
"resolution": 512,
"image_size": 512,
"min_aesthetic_score": 4.5,
"max_tokens": 8192,
"normalization": {
"mean": [
0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
-0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
-0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
],
"std": [
5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
]
},
"pretrained_slat_dec": "microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16"
}
},
"trainer": {
"name": "ImageConditionedSparseFlowMatchingCFGTrainer",
"args": {
"max_steps": 1000000,
"batch_size_per_gpu": 8,
"batch_split": 2,
"optimizer": {
"name": "AdamW",
"args": {
"lr": 1e-4,
"weight_decay": 0.01,
"betas": [0.9, 0.95],
"eps": 1e-8
}
},
"ema_rate": [
0.9999
],
"mix_precision_mode": "amp",
"mix_precision_dtype": "bfloat16",
"elastic": {
"name": "LinearMemoryController",
"args": {
"target_ratio": 0.75,
"max_mem_ratio_start": 0.5
}
},
"grad_clip": {
"name": "AdaptiveGradClipper",
"args": {
"max_norm": 1.0,
"clip_percentile": 95
}
},
"i_log": 500,
"i_sample": 10000,
"i_save": 10000,
"p_uncond": 0.1,
"t_schedule": {
"name": "uniform",
"args": {}
},
"sigma_min": 1e-5,
"image_cond_model": {
"name": "DinoV3FeatureExtractor",
"args": {
"model_name": "facebook/dinov3-vitl16-pretrain-lvd1689m",
"image_size": 512
}
}
}
}
}


@@ -0,0 +1,101 @@
{
"models": {
"denoiser": {
"name": "ElasticSLatFlowModel",
"args": {
"resolution": 64,
"in_channels": 32,
"out_channels": 32,
"model_channels": 1536,
"cond_channels": 1024,
"num_blocks": 30,
"num_heads": 12,
"mlp_ratio": 5.3334,
"pe_mode": "rope",
"share_mod": true,
"initialization": "scaled",
"qk_rms_norm": true,
"qk_rms_norm_cross": true
}
}
},
"dataset": {
"name": "ImageConditionedSLatShape",
"args": {
"resolution": 1024,
"image_size": 1024,
"min_aesthetic_score": 4.5,
"max_tokens": 32768,
"normalization": {
"mean": [
0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
-0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
-0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
],
"std": [
5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
]
},
"pretrained_slat_dec": "microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16"
}
},
"trainer": {
"name": "ImageConditionedSparseFlowMatchingCFGTrainer",
"args": {
"max_steps": 1000000,
"batch_size_per_gpu": 2,
"batch_split": 1,
"optimizer": {
"name": "AdamW",
"args": {
"lr": 2e-5,
"weight_decay": 0.01,
"betas": [0.9, 0.95],
"eps": 1e-8
}
},
"ema_rate": [
0.9999
],
"mix_precision_mode": "amp",
"mix_precision_dtype": "bfloat16",
"elastic": {
"name": "LinearMemoryController",
"args": {
"target_ratio": 0.75,
"max_mem_ratio_start": 0.25
}
},
"grad_clip": {
"name": "AdaptiveGradClipper",
"args": {
"max_norm": 1.0,
"clip_percentile": 95
}
},
"finetune_ckpt": {
"denoiser": "PATH_TO_512_CKPT"
},
"i_log": 500,
"i_sample": 1000,
"i_save": 1000,
"p_uncond": 0.1,
"t_schedule": {
"name": "uniform",
"args": {}
},
"sigma_min": 1e-5,
"image_cond_model": {
"name": "DinoV3FeatureExtractor",
"args": {
"model_name": "facebook/dinov3-vitl16-pretrain-lvd1689m",
"image_size": 1024
}
}
}
}
}


@@ -0,0 +1,119 @@
{
"models": {
"denoiser": {
"name": "ElasticSLatFlowModel",
"args": {
"resolution": 32,
"in_channels": 64,
"out_channels": 32,
"model_channels": 1536,
"cond_channels": 1024,
"num_blocks": 30,
"num_heads": 12,
"mlp_ratio": 5.3334,
"pe_mode": "rope",
"share_mod": true,
"initialization": "scaled",
"qk_rms_norm": true,
"qk_rms_norm_cross": true
}
}
},
"dataset": {
"name": "ImageConditionedSLatPbr",
"args": {
"resolution": 512,
"image_size": 512,
"min_aesthetic_score": 4.5,
"max_tokens": 8192,
"pbr_slat_normalization": {
"mean": [
3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
-1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862
],
"std": [
2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190
]
},
"shape_slat_normalization": {
"mean": [
0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
-0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
-0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
],
"std": [
5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
]
},
"attrs": [
"base_color",
"metallic",
"roughness",
"alpha"
],
"pretrained_pbr_slat_dec": "microsoft/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16",
"pretrained_shape_slat_dec": "microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16"
}
},
"trainer": {
"name": "ImageConditionedSparseFlowMatchingCFGTrainer",
"args": {
"max_steps": 1000000,
"batch_size_per_gpu": 8,
"batch_split": 2,
"optimizer": {
"name": "AdamW",
"args": {
"lr": 1e-4,
"weight_decay": 0.01,
"betas": [0.9, 0.95],
"eps": 1e-8
}
},
"ema_rate": [
0.9999
],
"mix_precision_mode": "amp",
"mix_precision_dtype": "bfloat16",
"elastic": {
"name": "LinearMemoryController",
"args": {
"target_ratio": 0.75,
"max_mem_ratio_start": 0.5
}
},
"grad_clip": {
"name": "AdaptiveGradClipper",
"args": {
"max_norm": 1.0,
"clip_percentile": 95
}
},
"i_log": 500,
"i_sample": 10000,
"i_save": 10000,
"p_uncond": 0.1,
"t_schedule": {
"name": "uniform",
"args": {}
},
"sigma_min": 1e-5,
"image_cond_model": {
"name": "DinoV3FeatureExtractor",
"args": {
"model_name": "facebook/dinov3-vitl16-pretrain-lvd1689m",
"image_size": 512
}
}
}
}
}


@@ -0,0 +1,120 @@
{
"models": {
"denoiser": {
"name": "ElasticSLatFlowModel",
"args": {
"resolution": 32,
"in_channels": 64,
"out_channels": 32,
"model_channels": 1536,
"cond_channels": 1024,
"num_blocks": 30,
"num_heads": 12,
"mlp_ratio": 5.3334,
"pe_mode": "rope",
"share_mod": true,
"initialization": "scaled",
"qk_rms_norm": true,
"qk_rms_norm_cross": true
}
}
},
"dataset": {
"name": "ImageConditionedSLatPbr",
"args": {
"resolution": 1024,
"image_size": 1024,
"min_aesthetic_score": 4.5,
"max_tokens": 32768,
"full_pbr": true,
"pbr_slat_normalization": {
"mean": [
3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
-1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862
],
"std": [
2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190
]
},
"shape_slat_normalization": {
"mean": [
0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
-0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
-0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
],
"std": [
5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
]
},
"attrs": [
"base_color",
"metallic",
"roughness",
"alpha"
],
"pretrained_pbr_slat_dec": "microsoft/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16",
"pretrained_shape_slat_dec": "microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16"
}
},
"trainer": {
"name": "ImageConditionedSparseFlowMatchingCFGTrainer",
"args": {
"max_steps": 1000000,
"batch_size_per_gpu": 2,
"batch_split": 1,
"optimizer": {
"name": "AdamW",
"args": {
"lr": 2e-5,
"weight_decay": 0.01,
"betas": [0.9, 0.95],
"eps": 1e-8
}
},
"ema_rate": [
0.9999
],
"mix_precision_mode": "amp",
"mix_precision_dtype": "bfloat16",
"elastic": {
"name": "LinearMemoryController",
"args": {
"target_ratio": 0.75,
"max_mem_ratio_start": 0.25
}
},
"grad_clip": {
"name": "AdaptiveGradClipper",
"args": {
"max_norm": 1.0,
"clip_percentile": 95
}
},
"i_log": 500,
"i_sample": 1000,
"i_save": 1000,
"p_uncond": 0.1,
"t_schedule": {
"name": "uniform",
"args": {}
},
"sigma_min": 1e-5,
"image_cond_model": {
"name": "DinoV3FeatureExtractor",
"args": {
"model_name": "facebook/dinov3-vitl16-pretrain-lvd1689m",
"image_size": 1024
}
}
}
}
}


@@ -0,0 +1,78 @@
{
"models": {
"denoiser": {
"name": "SparseStructureFlowModel",
"args": {
"resolution": 16,
"in_channels": 8,
"out_channels": 8,
"model_channels": 1536,
"cond_channels": 1024,
"num_blocks": 30,
"num_heads": 12,
"mlp_ratio": 5.3334,
"pe_mode": "rope",
"share_mod": true,
"initialization": "scaled",
"qk_rms_norm": true,
"qk_rms_norm_cross": true
}
}
},
"dataset": {
"name": "ImageConditionedSparseStructureLatent",
"args": {
"min_aesthetic_score": 4.5,
"image_size": 512,
"pretrained_ss_dec": "microsoft/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
}
},
"trainer": {
"name": "ImageConditionedFlowMatchingCFGTrainer",
"args": {
"max_steps": 1000000,
"batch_size_per_gpu": 8,
"batch_split": 4,
"optimizer": {
"name": "AdamW",
"args": {
"lr": 1e-4,
"weight_decay": 0.01,
"betas": [0.9, 0.95],
"eps": 1e-8
}
},
"ema_rate": [
0.9999
],
"mix_precision_mode": "amp",
"mix_precision_dtype": "bfloat16",
"grad_clip": {
"name": "AdaptiveGradClipper",
"args": {
"max_norm": 1.0,
"clip_percentile": 95
}
},
"i_log": 500,
"i_sample": 10000,
"i_save": 10000,
"p_uncond": 0.1,
"t_schedule": {
"name": "logitNormal",
"args": {
"mean": 1.0,
"std": 1.0
}
},
"sigma_min": 1e-5,
"image_cond_model": {
"name": "DinoV3FeatureExtractor",
"args": {
"model_name": "facebook/dinov3-vitl16-pretrain-lvd1689m",
"image_size": 512
}
}
}
}
}


@@ -0,0 +1,134 @@
{
"models": {
"encoder": {
"name": "FlexiDualGridVaeEncoder",
"args": {
"model_channels": [64, 128, 256, 512, 1024],
"latent_channels": 32,
"num_blocks": [0, 4, 8, 16, 4],
"block_type": [
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d"
],
"down_block_type": [
"SparseResBlockS2C3d",
"SparseResBlockS2C3d",
"SparseResBlockS2C3d",
"SparseResBlockS2C3d"
],
"block_args": [
{
"use_checkpoint": true
},
{
"use_checkpoint": true
},
{
"use_checkpoint": false
},
{
"use_checkpoint": false
},
{
"use_checkpoint": false
}
],
"use_fp16": true
}
},
"decoder": {
"name": "FlexiDualGridVaeDecoder",
"args": {
"resolution": 256,
"model_channels": [1024, 512, 256, 128, 64],
"latent_channels": 32,
"num_blocks": [4, 16, 8, 4, 0],
"block_type": [
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d"
],
"up_block_type": [
"SparseResBlockC2S3d",
"SparseResBlockC2S3d",
"SparseResBlockC2S3d",
"SparseResBlockC2S3d"
],
"block_args": [
{
"use_checkpoint": false
},
{
"use_checkpoint": false
},
{
"use_checkpoint": false
},
{
"use_checkpoint": true
},
{
"use_checkpoint": true
}
],
"use_fp16": true
}
}
},
"dataset": {
"name": "FlexiDualGridDataset",
"args": {
"resolution": 256,
"max_active_voxels": 1000000,
"max_num_faces": 1000000,
"min_aesthetic_score": 4.5
}
},
"trainer": {
"name": "ShapeVaeTrainer",
"args": {
"max_steps": 1000000,
"batch_size_per_gpu": 8,
"batch_split": 2,
"optimizer": {
"name": "AdamW",
"args": {
"lr": 1e-4,
"weight_decay": 0.0
}
},
"ema_rate": [
0.9999
],
"fp16_mode": "inflat_all",
"fp16_scale_growth": 0.001,
"grad_clip": {
"name": "AdaptiveGradClipper",
"args": {
"max_norm": 1.0,
"clip_percentile": 95
}
},
"i_log": 500,
"i_sample": 10000,
"i_save": 10000,
"lambda_subdiv": 0.1,
"lambda_intersected": 0.1,
"lambda_vertice": 1e-2,
"lambda_mask": 1,
"lambda_depth": 10,
"lambda_normal": 1,
"lambda_kl": 1e-6,
"lambda_ssim": 0.2,
"lambda_lpips": 0.2,
"camera_randomization_config": {
"radius_range": [2, 100]
}
}
}
}


@@ -0,0 +1,140 @@
{
"models": {
"encoder": {
"name": "FlexiDualGridVaeEncoder",
"args": {
"model_channels": [64, 128, 256, 512, 1024],
"latent_channels": 32,
"num_blocks": [0, 4, 8, 16, 4],
"block_type": [
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d"
],
"down_block_type": [
"SparseResBlockS2C3d",
"SparseResBlockS2C3d",
"SparseResBlockS2C3d",
"SparseResBlockS2C3d"
],
"block_args": [
{
"use_checkpoint": true
},
{
"use_checkpoint": true
},
{
"use_checkpoint": true
},
{
"use_checkpoint": true
},
{
"use_checkpoint": true
}
],
"use_fp16": true
}
},
"decoder": {
"name": "FlexiDualGridVaeDecoder",
"args": {
"resolution": 512,
"model_channels": [1024, 512, 256, 128, 64],
"latent_channels": 32,
"num_blocks": [4, 16, 8, 4, 0],
"block_type": [
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d"
],
"up_block_type": [
"SparseResBlockC2S3d",
"SparseResBlockC2S3d",
"SparseResBlockC2S3d",
"SparseResBlockC2S3d"
],
"block_args": [
{
"use_checkpoint": true
},
{
"use_checkpoint": true
},
{
"use_checkpoint": true
},
{
"use_checkpoint": true
},
{
"use_checkpoint": true
}
],
"use_fp16": true
}
}
},
"dataset": {
"name": "FlexiDualGridDataset",
"args": {
"resolution": 512,
"max_active_voxels": 1000000,
"max_num_faces": 1000000,
"min_aesthetic_score": 4.5
}
},
"trainer": {
"name": "ShapeVaeTrainer",
"args": {
"max_steps": 1000000,
"batch_size_per_gpu": 4,
"batch_split": 2,
"optimizer": {
"name": "AdamW",
"args": {
"lr": 1e-5,
"weight_decay": 0.0
}
},
"ema_rate": [
0.9999
],
"fp16_mode": "inflat_all",
"fp16_scale_growth": 0.001,
"grad_clip": {
"name": "AdaptiveGradClipper",
"args": {
"max_norm": 1.0,
"clip_percentile": 95
}
},
"finetune_ckpt": {
"encoder": "PATH_TO_ENCODER_CKPT",
"decoder": "PATH_TO_DECODER_CKPT"
},
"snapshot_batch_size": 1,
"i_log": 500,
"i_sample": 10000,
"i_save": 10000,
"lambda_subdiv": 0.1,
"lambda_intersected": 0.1,
"lambda_vertice": 1e-2,
"lambda_mask": 1,
"lambda_depth": 10,
"lambda_normal": 1,
"lambda_kl": 1e-6,
"lambda_ssim": 0.2,
"lambda_lpips": 0.2,
"render_resolution": 1024,
"camera_randomization_config": {
"radius_range": [2, 100]
}
}
}
}


@@ -0,0 +1,134 @@
{
"models": {
"encoder": {
"name": "SparseUnetVaeEncoder",
"args": {
"in_channels": 6,
"model_channels": [64, 128, 256, 512, 1024],
"latent_channels": 32,
"num_blocks": [0, 4, 8, 16, 4],
"block_type": [
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d"
],
"down_block_type": [
"SparseResBlockS2C3d",
"SparseResBlockS2C3d",
"SparseResBlockS2C3d",
"SparseResBlockS2C3d"
],
"block_args": [
{
"use_checkpoint": true
},
{
"use_checkpoint": true
},
{
"use_checkpoint": false
},
{
"use_checkpoint": false
},
{
"use_checkpoint": false
}
],
"use_fp16": true
}
},
"decoder": {
"name": "SparseUnetVaeDecoder",
"args": {
"out_channels": 6,
"model_channels": [1024, 512, 256, 128, 64],
"latent_channels": 32,
"num_blocks": [4, 16, 8, 4, 0],
"block_type": [
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d"
],
"up_block_type": [
"SparseResBlockC2S3d",
"SparseResBlockC2S3d",
"SparseResBlockC2S3d",
"SparseResBlockC2S3d"
],
"block_args": [
{
"use_checkpoint": false
},
{
"use_checkpoint": false
},
{
"use_checkpoint": false
},
{
"use_checkpoint": true
},
{
"use_checkpoint": true
}
],
"use_fp16": true,
"pred_subdiv": false
}
}
},
"dataset": {
"name": "SparseVoxelPbrDataset",
"args": {
"resolution": 256,
"min_aesthetic_score": 4.5,
"max_active_voxels": 1000000,
"max_num_faces": 1000000,
"with_mesh": false,
"attrs": [
"base_color",
"metallic",
"roughness",
"alpha"
]
}
},
"trainer": {
"name": "PbrVaeTrainer",
"args": {
"max_steps": 1000000,
"batch_size_per_gpu": 8,
"batch_split": 1,
"optimizer": {
"name": "AdamW",
"args": {
"lr": 1e-4,
"weight_decay": 0.0
}
},
"ema_rate": [
0.9999
],
"fp16_mode": "inflat_all",
"fp16_scale_growth": 0.001,
"grad_clip": {
"name": "AdaptiveGradClipper",
"args": {
"max_norm": 1.0,
"clip_percentile": 95
}
},
"i_log": 500,
"i_sample": 10000,
"i_save": 10000,
"lambda_kl": 1e-6,
"loss_type": "l1",
"lambda_render": 0.0
}
}
}


@@ -0,0 +1,144 @@
{
"models": {
"encoder": {
"name": "SparseUnetVaeEncoder",
"args": {
"in_channels": 6,
"model_channels": [64, 128, 256, 512, 1024],
"latent_channels": 32,
"num_blocks": [0, 4, 8, 16, 4],
"block_type": [
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d"
],
"down_block_type": [
"SparseResBlockS2C3d",
"SparseResBlockS2C3d",
"SparseResBlockS2C3d",
"SparseResBlockS2C3d"
],
"block_args": [
{
"use_checkpoint": true
},
{
"use_checkpoint": true
},
{
"use_checkpoint": true
},
{
"use_checkpoint": true
},
{
"use_checkpoint": true
}
],
"use_fp16": true
}
},
"decoder": {
"name": "SparseUnetVaeDecoder",
"args": {
"out_channels": 6,
"model_channels": [1024, 512, 256, 128, 64],
"latent_channels": 32,
"num_blocks": [4, 16, 8, 4, 0],
"block_type": [
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d",
"SparseConvNeXtBlock3d"
],
"up_block_type": [
"SparseResBlockC2S3d",
"SparseResBlockC2S3d",
"SparseResBlockC2S3d",
"SparseResBlockC2S3d"
],
"block_args": [
{
"use_checkpoint": true
},
{
"use_checkpoint": true
},
{
"use_checkpoint": true
},
{
"use_checkpoint": true
},
{
"use_checkpoint": true
}
],
"use_fp16": true,
"pred_subdiv": false
}
}
},
"dataset": {
"name": "SparseVoxelPbrDataset",
"args": {
"resolution": 512,
"min_aesthetic_score": 4.5,
"max_active_voxels": 1000000,
"max_num_faces": 1000000,
"attrs": [
"base_color",
"metallic",
"roughness",
"alpha"
]
}
},
"trainer": {
"name": "PbrVaeTrainer",
"args": {
"max_steps": 1000000,
"batch_size_per_gpu": 4,
"batch_split": 2,
"optimizer": {
"name": "AdamW",
"args": {
"lr": 1e-5,
"weight_decay": 0.0
}
},
"ema_rate": [
0.9999
],
"fp16_mode": "inflat_all",
"fp16_scale_growth": 0.001,
"grad_clip": {
"name": "AdaptiveGradClipper",
"args": {
"max_norm": 1.0,
"clip_percentile": 95
}
},
"finetune_ckpt": {
"encoder": "PATH_TO_ENCODER_CKPT",
"decoder": "PATH_TO_DECODER_CKPT"
},
"snapshot_batch_size": 1,
"render_resolution": 512,
"i_log": 500,
"i_sample": 10000,
"i_save": 10000,
"lambda_kl": 1e-6,
"lambda_render": 1.0,
"loss_type": "l1",
"lambda_ssim": 0.2,
"lambda_lpips": 0.2,
"camera_randomization_config": {
"radius_range": [2, 100]
}
}
}
}

data_toolkit/README.md

@@ -0,0 +1,172 @@
# Dataset Preparation Toolkit
This toolkit provides a comprehensive pipeline for preparing 3D datasets, including downloading, processing, voxelizing, and latent encoding for SC-VAE and Flow Model training.
### Step 1: Install Dependencies
Initialize the environment and install necessary dependencies:
```bash
. ./data_toolkit/setup.sh
```
### Step 2: Initialize Metadata
Before processing, load the dataset metadata.
```bash
python data_toolkit/build_metadata.py <SUBSET> --root <ROOT> [--source <SOURCE>]
```
**Arguments:**
- `SUBSET`: Target dataset subset. Options: `ObjaverseXL`, `ABO`, `HSSD`, `TexVerse` (Training sets); `SketchfabPicked`, `Toys4k` (Test sets).
- `ROOT`: Root directory to save the data.
- `SOURCE`: Data source (Required if `SUBSET` is `ObjaverseXL`). Options: `sketchfab`, `github`.
**Example:**
Load metadata for `ObjaverseXL` (sketchfab) and save to `datasets/ObjaverseXL_sketchfab`:
```bash
python data_toolkit/build_metadata.py ObjaverseXL --source sketchfab --root datasets/ObjaverseXL_sketchfab
```
### Step 3: Download Data
Download the 3D assets to local storage.
```bash
python data_toolkit/download.py <SUBSET> --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>]
```
**Arguments:**
- `RANK` / `WORLD_SIZE`: Parameters for multi-node distributed downloading.
**Example:**
To download the `ObjaverseXL` subset:
> **Note:** The example below sets a large `WORLD_SIZE` (160,000) for demonstration purposes, meaning only a tiny fraction of the dataset will be downloaded by this single process.
```bash
python data_toolkit/download.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab --world_size 160000
```
*Attention: Some datasets may require an interactive Hugging Face login or manual steps. Please follow any on-screen instructions.*
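To cover the full dataset, shard the download across multiple processes with `--rank`/`--world_size`. A minimal sketch launching eight local workers (the shard count is illustrative):

```bash
# Each worker downloads a disjoint 1/8 slice of the asset list
for RANK in 0 1 2 3 4 5 6 7; do
    python data_toolkit/download.py ObjaverseXL \
        --root datasets/ObjaverseXL_sketchfab \
        --rank $RANK --world_size 8 &
done
wait
```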
**Update Metadata:**
After downloading, update the metadata registry:
```bash
python data_toolkit/build_metadata.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab
```
### Step 4: Process Mesh and PBR Textures
Standardize 3D assets by dumping their meshes and PBR textures.
*Note: This process utilizes the CPU.*
```bash
# Dump Meshes
python data_toolkit/dump_mesh.py <SUBSET> --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>]
# Dump PBR Textures
python data_toolkit/dump_pbr.py <SUBSET> --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>]
# Get statistics of the assets
python data_toolkit/asset_stats.py --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>]
```
**Example:**
```bash
python data_toolkit/dump_mesh.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab
python data_toolkit/dump_pbr.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab
python data_toolkit/asset_stats.py --root datasets/ObjaverseXL_sketchfab
```
**Update Metadata:**
```bash
python data_toolkit/build_metadata.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab
```
### Step 5: Convert to O-Voxels
Convert the processed meshes and textures into the O-Voxel format.
*Note: This process utilizes the CPU.*
```bash
python data_toolkit/dual_grid.py <SUBSET> --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>] [--resolution <RESOLUTION>]
python data_toolkit/voxelize_pbr.py <SUBSET> --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>] [--resolution <RESOLUTION>]
```
**Arguments:**
- `RESOLUTION`: Target resolutions for O-Voxels, comma-separated (e.g., `256,512,1024`). Default is `256`.
**Example:**
Convert `ObjaverseXL` to resolutions 256, 512, and 1024:
```bash
python data_toolkit/dual_grid.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab --resolution 256,512,1024
python data_toolkit/voxelize_pbr.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab --resolution 256,512,1024
```
**✅ At this point, the dataset is ready for SC-VAE training.**
### Step 6: Encode Latents
Encode the processed assets into shape, PBR, and sparse-structure (SS) latents for flow model training.
```bash
# 1. Encode Shape Latents
python data_toolkit/encode_shape_latent.py --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>] [--resolution <RESOLUTION>]
# 2. Encode PBR Latents
python data_toolkit/encode_pbr_latent.py --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>] [--resolution <RESOLUTION>]
# 3. Update Metadata (Required before next step)
python data_toolkit/build_metadata.py <SUBSET> --root <ROOT>
# 4. Encode Sparse Structure (SS) Latents
python data_toolkit/encode_ss_latent.py --root <ROOT> --shape_latent_name <SHAPE_LATENT_NAME> [--rank <RANK> --world_size <WORLD_SIZE>] [--resolution <SS_RESOLUTION>]
```
**Arguments:**
- `RESOLUTION`: Input O-Voxel resolution. Default is `1024`.
- `SS_RESOLUTION`: Resolution for sparse structures. Default is `64`.
- `SHAPE_LATENT_NAME`: The version name of the shape latent to build sparse structures from, i.e. the encoder name plus resolution (e.g., `shape_enc_next_dc_f16c32_fp16_1024`).
**Example:**
```bash
python data_toolkit/encode_shape_latent.py --root datasets/ObjaverseXL_sketchfab --resolution 512
python data_toolkit/encode_pbr_latent.py --root datasets/ObjaverseXL_sketchfab --resolution 512
python data_toolkit/encode_shape_latent.py --root datasets/ObjaverseXL_sketchfab --resolution 1024
python data_toolkit/encode_pbr_latent.py --root datasets/ObjaverseXL_sketchfab --resolution 1024
# Update metadata
python data_toolkit/build_metadata.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab
# Encode SS Latents
python data_toolkit/encode_ss_latent.py --root datasets/ObjaverseXL_sketchfab --shape_latent_name shape_enc_next_dc_f16c32_fp16_1024 --resolution 64
# Final Metadata Update
python data_toolkit/build_metadata.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab
```
### Step 7: Render Image Conditions
Render multi-view images to train the image-conditioned generator.
*Note: This process utilizes the GPU.*
```bash
python data_toolkit/render_cond.py <SUBSET> --root <ROOT> [--num_views <NUM_VIEWS>] [--rank <RANK> --world_size <WORLD_SIZE>]
```
**Arguments:**
- `NUM_VIEWS`: Number of views to render per asset. Default is `16`.
**Example:**
```bash
python data_toolkit/render_cond.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab
```
**Final Metadata Update:**
```bash
python data_toolkit/build_metadata.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab
```
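Once all steps complete, the dataset root should roughly match the directory names referenced by the training commands in the main README. This layout is inferred from those commands; exact names depend on the resolutions and encoders you chose:

```
datasets/ObjaverseXL_sketchfab/
├── metadata.csv
├── asset_stats/
├── mesh_dumps/
├── pbr_dumps/
├── dual_grid_256/                # plus _512, _1024 if requested
├── pbr_voxels_256/
├── shape_latents/shape_enc_next_dc_f16c32_fp16_512/
├── pbr_latents/tex_enc_next_dc_f16c32_fp16_512/
├── ss_latents/ss_enc_conv3d_16l8_fp16_64/
└── renders_cond/
```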

data_toolkit/asset_stats.py

@@ -0,0 +1,132 @@
import os
import argparse
import pickle
from tqdm import tqdm
import pandas as pd
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--mesh_dump_root', type=str, default=None,
help='Directory to save the mesh dumps')
parser.add_argument('--pbr_dump_root', type=str, default=None,
help='Directory to save the pbr dumps')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
parser.add_argument('--max_workers', type=int, default=0)
opt = parser.parse_args()
opt = edict(vars(opt))
opt.mesh_dump_root = opt.mesh_dump_root or opt.root
opt.pbr_dump_root = opt.pbr_dump_root or opt.root
os.makedirs(os.path.join(opt.root, 'asset_stats', 'new_records'), exist_ok=True)
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'asset_stats', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'asset_stats','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.mesh_dump_root, 'mesh_dumps','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')).set_index('sha256'))
metadata = metadata.reset_index()
if opt.instances is None:
if 'num_faces' in metadata.columns:
metadata = metadata[metadata['num_faces'].isnull()]
metadata = metadata[(metadata['mesh_dumped'] == True) | (metadata['pbr_dumped'] == True)]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
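    # Shard the remaining work evenly across (rank, world_size) workers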
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
print(f'Processing {len(metadata)} objects...')
# process objects
records = []
with ThreadPoolExecutor(max_workers=opt.max_workers or os.cpu_count()) as executor, \
tqdm(total=len(metadata), desc='Processing objects') as pbar:
def worker(metadatum):
try:
sha256 = metadatum['sha256']
if metadatum['pbr_dumped'] == True:
with open(os.path.join(opt.pbr_dump_root, 'pbr_dumps', f'{sha256}.pickle'), 'rb') as f:
dump = pickle.load(f)
num_faces = 0
num_vertices = 0
for obj in dump['objects']:
if obj['vertices'].size == 0 or obj['faces'].size == 0:
continue
num_faces += obj['faces'].shape[0]
num_vertices += obj['vertices'].shape[0]
num_basecolor_tex = 0
num_metallic_tex = 0
num_roughness_tex = 0
num_alpha_tex = 0
for mat in dump['materials']:
if mat['baseColorTexture'] is not None:
num_basecolor_tex += 1
if mat['metallicTexture'] is not None:
num_metallic_tex += 1
if mat['roughnessTexture'] is not None:
num_roughness_tex += 1
if mat['alphaTexture'] is not None:
num_alpha_tex += 1
record = {
'sha256': sha256,
'num_faces': num_faces,
'num_vertices': num_vertices,
'num_basecolor_tex': num_basecolor_tex,
'num_metallic_tex': num_metallic_tex,
'num_roughness_tex': num_roughness_tex,
'num_alpha_tex': num_alpha_tex,
}
records.append(record)
else:
with open(os.path.join(opt.mesh_dump_root,'mesh_dumps', f'{sha256}.pickle'), 'rb') as f:
dump = pickle.load(f)
num_faces = 0
num_vertices = 0
for obj in dump['objects']:
if obj['vertices'].size == 0 or obj['faces'].size == 0:
continue
num_faces += obj['faces'].shape[0]
num_vertices += obj['vertices'].shape[0]
record = {
'sha256': sha256,
'num_faces': num_faces,
'num_vertices': num_vertices,
}
records.append(record)
pbar.update()
except Exception as e:
print(f'Error processing {sha256}: {e}')
pbar.update()
for metadatum in metadata.to_dict('records'):
executor.submit(worker, metadatum)
executor.shutdown(wait=True)
# save records
records = pd.DataFrame.from_records(records)
records.to_csv(os.path.join(opt.root, 'asset_stats', 'new_records', f'part_{opt.rank}.csv'), index=False)


@@ -0,0 +1,242 @@
import argparse, sys, os, math, io
from typing import *
import bpy
import bmesh
from mathutils import Vector, Matrix
import numpy as np
import pickle
"""=============== BLENDER ==============="""
IMPORT_FUNCTIONS: Dict[str, Callable] = {
"obj": bpy.ops.import_scene.obj if bpy.app.version[0] < 4 else bpy.ops.wm.obj_import,
"glb": bpy.ops.import_scene.gltf,
"gltf": bpy.ops.import_scene.gltf,
"usd": bpy.ops.import_scene.usd,
"fbx": bpy.ops.import_scene.fbx,
"stl": bpy.ops.import_mesh.stl if bpy.app.version[0] < 4 else bpy.ops.wm.stl_import,
"usda": bpy.ops.import_scene.usda,
"dae": bpy.ops.wm.collada_import,
"ply": bpy.ops.import_mesh.ply if bpy.app.version[0] < 4 else bpy.ops.wm.ply_import,
"abc": bpy.ops.wm.alembic_import,
"blend": bpy.ops.wm.append,
}
def init_scene() -> None:
"""Resets the scene to a clean state.
Returns:
None
"""
# delete everything
for obj in bpy.data.objects:
bpy.data.objects.remove(obj, do_unlink=True)
# delete all the materials
for material in bpy.data.materials:
bpy.data.materials.remove(material, do_unlink=True)
# delete all the textures
for texture in bpy.data.textures:
bpy.data.textures.remove(texture, do_unlink=True)
# delete all the images
for image in bpy.data.images:
bpy.data.images.remove(image, do_unlink=True)
def load_object(object_path: str) -> None:
"""Loads a model with a supported file extension into the scene.
Args:
object_path (str): Path to the model file.
Raises:
ValueError: If the file extension is not supported.
Returns:
None
"""
    file_extension = object_path.split(".")[-1].lower()
    # reject extensions that no importer handles
    if file_extension != "usdz" and file_extension not in IMPORT_FUNCTIONS:
        raise ValueError(f"Unsupported file type: {object_path}")
if file_extension == "usdz":
# install usdz io package
dirname = os.path.dirname(os.path.realpath(__file__))
usdz_package = os.path.join(dirname, "io_scene_usdz.zip")
bpy.ops.preferences.addon_install(filepath=usdz_package)
# enable it
addon_name = "io_scene_usdz"
bpy.ops.preferences.addon_enable(module=addon_name)
# import the usdz
from io_scene_usdz.import_usdz import import_usdz
        import_usdz(bpy.context, filepath=object_path, materials=True, animations=True)
return None
# load from existing import functions
import_function = IMPORT_FUNCTIONS[file_extension]
print(f"Loading object from {object_path}")
if file_extension == "blend":
import_function(directory=object_path, link=False)
elif file_extension in {"glb", "gltf"}:
import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS', bone_heuristic='TEMPERANCE')
else:
import_function(filepath=object_path)
def delete_invisible_objects() -> None:
"""Deletes all invisible objects in the scene.
Returns:
None
"""
to_remove = []
for obj in bpy.context.scene.objects:
if obj.hide_viewport or obj.hide_render:
obj.hide_viewport = False
obj.hide_render = False
obj.hide_select = False
to_remove.append(obj)
for obj in to_remove:
bpy.data.objects.remove(obj, do_unlink=True)
# Delete invisible collections
invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
for col in invisible_collections:
bpy.data.collections.remove(col)
def scene_bbox() -> Tuple[Vector, Vector]:
"""Returns the bounding box of the scene.
Taken from Shap-E rendering script
(https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)
Returns:
Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
"""
bbox_min = (math.inf,) * 3
bbox_max = (-math.inf,) * 3
found = False
scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)]
for obj in scene_meshes:
found = True
for coord in obj.bound_box:
coord = Vector(coord)
coord = obj.matrix_world @ coord
bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
if not found:
raise RuntimeError("no objects in scene to compute bounding box for")
return Vector(bbox_min), Vector(bbox_max)
def normalize_scene() -> Tuple[float, Vector]:
"""Normalizes the scene by scaling and translating it to fit in a unit cube centered
at the origin.
Mostly taken from the Point-E / Shap-E rendering script
(https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
but fix for multiple root objects: (see bug report here:
https://github.com/openai/shap-e/pull/60).
Returns:
Tuple[float, Vector]: The scale factor and the offset applied to the scene.
"""
scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent]
if len(scene_root_objects) > 1:
# create an empty object to be used as a parent for all root objects
scene = bpy.data.objects.new("ParentEmpty", None)
bpy.context.scene.collection.objects.link(scene)
# parent all root objects to the empty object
for obj in scene_root_objects:
obj.parent = scene
else:
scene = scene_root_objects[0]
bbox_min, bbox_max = scene_bbox()
scale = 1 / max(bbox_max - bbox_min)
scene.scale = scene.scale * scale
# Apply scale to matrix_world.
bpy.context.view_layer.update()
bbox_min, bbox_max = scene_bbox()
offset = -(bbox_min + bbox_max) / 2
scene.matrix_world.translation += offset
return scale, offset
def main(arg):
# Initialize context
if arg.object.endswith(".blend"):
delete_invisible_objects()
else:
init_scene()
load_object(arg.object)
print('[INFO] Scene initialized.')
# Normalize scene
scale, offset = normalize_scene()
print('[INFO] Scene normalized.')
# Start dumping
depsgraph = bpy.context.evaluated_depsgraph_get()
scene = bpy.context.scene
output = {
'objects': [],
}
# Dumping meshes
for obj in scene.objects:
if obj.type != 'MESH':
continue
pack = {
"vertices": None,
"faces": None,
}
eval_obj = obj.evaluated_get(depsgraph)
eval_mesh = eval_obj.to_mesh()
bm = bmesh.new()
bm.from_mesh(eval_mesh)
bm.transform(obj.matrix_world)
bmesh.ops.triangulate(bm, faces=bm.faces)
bm.to_mesh(eval_mesh)
bm.free()
pack["vertices"] = np.array([
v.co[:] for v in eval_mesh.vertices
], dtype=np.float32) # (N, 3)
pack["faces"] = np.array([
[eval_mesh.loops[i].vertex_index for i in poly.loop_indices]
for poly in eval_mesh.polygons
], dtype=np.int32) # (F, 3)
output['objects'].append(pack)
# Save output
os.makedirs(os.path.dirname(arg.output_path), exist_ok=True)
with open(arg.output_path, 'wb') as f:
pickle.dump(output, f)
print('[INFO] Output saved to {}.'.format(arg.output_path))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Dumps the triangulated mesh of a given 3D model file.')
parser.add_argument('--object', type=str, help='Path to the 3D model file to be rendered.')
parser.add_argument('--output_path', type=str, default='/tmp', help='The path the output will be dumped to.')
argv = sys.argv[sys.argv.index("--") + 1:]
args = parser.parse_args(argv)
main(args)


@@ -0,0 +1,485 @@
import argparse, sys, os, math, io
from typing import *
import bpy
import bmesh
from mathutils import Vector, Matrix
import numpy as np
from PIL import Image
import pickle
"""=============== BLENDER ==============="""
IMPORT_FUNCTIONS: Dict[str, Callable] = {
"obj": bpy.ops.import_scene.obj if bpy.app.version[0] < 4 else bpy.ops.wm.obj_import,
"glb": bpy.ops.import_scene.gltf,
"gltf": bpy.ops.import_scene.gltf,
"usd": bpy.ops.import_scene.usd,
"fbx": bpy.ops.import_scene.fbx,
"stl": bpy.ops.import_mesh.stl if bpy.app.version[0] < 4 else bpy.ops.wm.stl_import,
"usda": bpy.ops.import_scene.usda,
"dae": bpy.ops.wm.collada_import,
"ply": bpy.ops.import_mesh.ply if bpy.app.version[0] < 4 else bpy.ops.wm.ply_import,
"abc": bpy.ops.wm.alembic_import,
"blend": bpy.ops.wm.append,
}
def init_scene() -> None:
"""Resets the scene to a clean state.
Returns:
None
"""
# delete everything
for obj in bpy.data.objects:
bpy.data.objects.remove(obj, do_unlink=True)
# delete all the materials
for material in bpy.data.materials:
bpy.data.materials.remove(material, do_unlink=True)
# delete all the textures
for texture in bpy.data.textures:
bpy.data.textures.remove(texture, do_unlink=True)
# delete all the images
for image in bpy.data.images:
bpy.data.images.remove(image, do_unlink=True)
def load_object(object_path: str) -> None:
"""Loads a model with a supported file extension into the scene.
Args:
object_path (str): Path to the model file.
Raises:
ValueError: If the file extension is not supported.
Returns:
None
"""
    file_extension = object_path.split(".")[-1].lower()
    # reject extensions that no importer handles
    if file_extension != "usdz" and file_extension not in IMPORT_FUNCTIONS:
        raise ValueError(f"Unsupported file type: {object_path}")
if file_extension == "usdz":
# install usdz io package
dirname = os.path.dirname(os.path.realpath(__file__))
usdz_package = os.path.join(dirname, "io_scene_usdz.zip")
bpy.ops.preferences.addon_install(filepath=usdz_package)
# enable it
addon_name = "io_scene_usdz"
bpy.ops.preferences.addon_enable(module=addon_name)
# import the usdz
from io_scene_usdz.import_usdz import import_usdz
        import_usdz(bpy.context, filepath=object_path, materials=True, animations=True)
return None
# load from existing import functions
import_function = IMPORT_FUNCTIONS[file_extension]
print(f"Loading object from {object_path}")
if file_extension == "blend":
import_function(directory=object_path, link=False)
elif file_extension in {"glb", "gltf"}:
import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS', bone_heuristic='TEMPERANCE')
else:
import_function(filepath=object_path)
def delete_invisible_objects() -> None:
"""Deletes all invisible objects in the scene.
Returns:
None
"""
to_remove = []
for obj in bpy.context.scene.objects:
if obj.hide_viewport or obj.hide_render:
obj.hide_viewport = False
obj.hide_render = False
obj.hide_select = False
to_remove.append(obj)
for obj in to_remove:
bpy.data.objects.remove(obj, do_unlink=True)
# Delete invisible collections
invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
for col in invisible_collections:
bpy.data.collections.remove(col)
def scene_bbox() -> Tuple[Vector, Vector]:
"""Returns the bounding box of the scene.
Taken from Shap-E rendering script
(https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)
Returns:
Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
"""
bbox_min = (math.inf,) * 3
bbox_max = (-math.inf,) * 3
found = False
scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)]
for obj in scene_meshes:
found = True
for coord in obj.bound_box:
coord = Vector(coord)
coord = obj.matrix_world @ coord
bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
if not found:
raise RuntimeError("no objects in scene to compute bounding box for")
return Vector(bbox_min), Vector(bbox_max)
def normalize_scene() -> Tuple[float, Vector]:
"""Normalizes the scene by scaling and translating it to fit in a unit cube centered
at the origin.
Mostly taken from the Point-E / Shap-E rendering script
(https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
but fix for multiple root objects: (see bug report here:
https://github.com/openai/shap-e/pull/60).
Returns:
Tuple[float, Vector]: The scale factor and the offset applied to the scene.
"""
scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent]
if len(scene_root_objects) > 1:
# create an empty object to be used as a parent for all root objects
scene = bpy.data.objects.new("ParentEmpty", None)
bpy.context.scene.collection.objects.link(scene)
# parent all root objects to the empty object
for obj in scene_root_objects:
obj.parent = scene
else:
scene = scene_root_objects[0]
bbox_min, bbox_max = scene_bbox()
scale = 1 / max(bbox_max - bbox_min)
scene.scale = scene.scale * scale
# Apply scale to matrix_world.
bpy.context.view_layer.update()
bbox_min, bbox_max = scene_bbox()
offset = -(bbox_min + bbox_max) / 2
scene.matrix_world.translation += offset
return scale, offset
# =============== NODE TREE PARSING ===============
def extract_image(tex_node, channels):
image = tex_node.image
pixels = np.array(image.pixels[:])
data = pixels.reshape(image.size[1], image.size[0], -1)
data = data[..., channels]
if data.dtype != np.uint8:
data = np.clip(data, 0.0, 1.0)
data = (data * 255).astype(np.uint8)
if len(data.shape) == 2: # Single channel
pil_image = Image.fromarray(data, mode='L')
elif data.shape[2] == 3:
pil_image = Image.fromarray(data, mode='RGB')
elif data.shape[2] == 4:
pil_image = Image.fromarray(data, mode='RGBA')
else:
raise ValueError("Unsupported channel shape for image")
buffer = io.BytesIO()
pil_image.save(buffer, format='PNG')
png_bytes = buffer.getvalue()
return {
'image': png_bytes,
'interpolation': tex_node.interpolation,
'extension': tex_node.extension,
}
def try_extract_image(link, expected_channel='RGB'):
"""
Tries to extract an image from a texture node link.
Supported sub tree modes:
- RGB:
TEX_IMAGE ->
- R, G, B:
TEX_IMAGE -> SEPARATE_COLOR ->
- A:
TEX_IMAGE ->
"""
assert expected_channel in ['RGB', 'R', 'G', 'B', 'A'], "Unsupported channel"
if expected_channel == 'RGB':
assert link.from_node.type == 'TEX_IMAGE', "Material is not supported"
assert link.from_socket.name == 'Color', "Material is not supported"
tex_node = link.from_node
return extract_image(tex_node, [0, 1, 2])
if expected_channel in ['R', 'G', 'B']:
socket_name = {
'R': 'Red',
'G': 'Green',
'B': 'Blue',
}[expected_channel]
assert link.from_node.type == 'SEPARATE_COLOR' and link.from_node.mode == 'RGB', \
f"Material is not supported, {link.from_node.type}, {link.from_node.mode}"
assert link.from_socket.name == socket_name, "Material is not supported"
sep_node = link.from_node
assert sep_node.inputs[0].is_linked and sep_node.inputs[0].links[0].from_node.type == 'TEX_IMAGE', \
"Material is not supported"
assert sep_node.inputs[0].links[0].from_socket.name == 'Color', "Material is not supported"
tex_node = sep_node.inputs[0].links[0].from_node
channel_index = {
'R': 0,
'G': 1,
'B': 2,
}[expected_channel]
return extract_image(tex_node, channel_index)
if expected_channel == 'A':
assert link.from_node.type == 'TEX_IMAGE', "Material is not supported"
assert link.from_socket.name == 'Alpha', "Material is not supported"
tex_node = link.from_node
return extract_image(tex_node, 3)
def try_extract_factor(link, mode='color'):
"""
Tries to extract a factor from a math node link.
Supported sub tree modes:
- color:
ANY -> MIX(MULTIPLY) ->
- scalar:
ANY -> MATH(MULTIPLY) ->
"""
    assert mode in ['color', 'scalar'], "Unsupported mode"
if mode == 'color':
if link.from_node.type == 'MIX':
mix_node = link.from_node
assert mix_node.data_type == 'RGBA' and mix_node.blend_type == 'MULTIPLY', f"Material is not supported, {mix_node.data_type}, {mix_node.blend_type}"
assert not mix_node.inputs['Factor'].is_linked and mix_node.inputs['Factor'].default_value == 1.0, \
"Material is not supported"
if mix_node.inputs['A'].is_linked:
assert not mix_node.inputs['B'].is_linked, "Material is not supported"
return (list(mix_node.inputs['B'].default_value)[:3], mix_node.inputs['A'].links[0])
else:
assert not mix_node.inputs['A'].is_linked, "Material is not supported"
assert mix_node.inputs['B'].is_linked, "Material is not supported"
return (list(mix_node.inputs['A'].default_value)[:3], mix_node.inputs['B'].links[0])
return ([1.0, 1.0, 1.0], link)
    if mode == 'scalar':
if link.from_node.type == 'MATH':
math_node = link.from_node
assert math_node.operation == 'MULTIPLY', "Material is not supported"
assert math_node.inputs[0].is_linked, "Material is not supported"
assert not math_node.inputs[1].is_linked, "Material is not supported"
return (math_node.inputs[1].default_value, math_node.inputs[0].links[0])
return (1.0, link)
def try_extract_image_with_factor(link, expected_channel='RGB'):
"""
Tries to extract an image and a factor from a texture node link.
"""
factor, link = try_extract_factor(link, 'color' if expected_channel in ['RGB'] else 'scalar')
image = try_extract_image(link, expected_channel)
return (factor, image)
def main(arg):
# Initialize context
if arg.object.endswith(".blend"):
delete_invisible_objects()
else:
init_scene()
load_object(arg.object)
print('[INFO] Scene initialized.')
# Normalize scene
scale, offset = normalize_scene()
print('[INFO] Scene normalized.')
# Start dumping
depsgraph = bpy.context.evaluated_depsgraph_get()
scene = bpy.context.scene
output = {
'materials': [],
'objects': [],
}
# Dumping materials
for mat in bpy.data.materials:
assert mat.use_nodes == True, "Material is not supported"
pack = {
"baseColorFactor": [1.0, 1.0, 1.0],
"alphaFactor": 1.0,
"metallicFactor": 1.0,
"roughnessFactor": 1.0,
"alphaMode": "OPAQUE",
"alphaCutoff": 0.5,
"baseColorTexture": None,
"alphaTexture": None,
"metallicTexture": None,
"roughnessTexture": None,
}
try:
principled_node = mat.node_tree.nodes.get('Principled BSDF')
assert principled_node is not None, "Material is not supported"
# Handle base color
if not principled_node.inputs['Base Color'].is_linked:
pack["baseColorFactor"] = list(principled_node.inputs['Base Color'].default_value)
else:
link = principled_node.inputs['Base Color'].links[0]
if link.from_node.type == 'RGB':
pack["baseColorFactor"] = list(link.from_node.outputs[0].default_value)
else:
factor, image = try_extract_image_with_factor(link, 'RGB')
pack["baseColorFactor"] = factor
pack["baseColorTexture"] = image
# Handle alpha
if not principled_node.inputs['Alpha'].is_linked:
pack["alphaFactor"] = principled_node.inputs['Alpha'].default_value
if pack["alphaFactor"] < 1.0:
pack["alphaMode"] = "BLEND"
else:
link = principled_node.inputs['Alpha'].links[0]
node = link.from_node
if node.type == 'VALUE':
pack["alphaFactor"] = node.outputs[0].default_value
if pack["alphaFactor"] < 1.0:
pack["alphaMode"] = "BLEND"
else:
pack["alphaMode"] = "BLEND"
if node.type == 'MATH':
if node.operation == 'ROUND':
assert node.inputs[0].is_linked, "Material is not supported"
pack["alphaMode"] = "MASK"
link = node.inputs[0].links[0]
elif node.operation == 'SUBTRACT':
assert node.inputs[0].default_value == 1.0 and \
node.inputs[1].is_linked and \
node.inputs[1].links[0].from_node.type == 'MATH' and \
node.inputs[1].links[0].from_node.operation == 'LESS_THAN', \
"Material is not supported"
assert node.inputs[1].links[0].from_node.inputs[0].is_linked, "Material is not supported"
pack["alphaMode"] = "MASK"
pack["alphaCutoff"] = node.inputs[1].links[0].from_node.inputs[1].default_value
link = node.inputs[1].links[0].from_node.inputs[0].links[0]
factor, image = try_extract_image_with_factor(link, 'A')
pack["alphaFactor"] = factor
pack["alphaTexture"] = image
# Handle metallic
if not principled_node.inputs['Metallic'].is_linked:
pack["metallicFactor"] = principled_node.inputs['Metallic'].default_value
else:
link = principled_node.inputs['Metallic'].links[0]
node = link.from_node
if node.type == 'VALUE':
pack["metallicFactor"] = node.outputs[0].default_value
else:
factor, image = try_extract_image_with_factor(link, 'B')
pack["metallicFactor"] = factor
pack["metallicTexture"] = image
# Handle roughness
if not principled_node.inputs['Roughness'].is_linked:
pack["roughnessFactor"] = principled_node.inputs['Roughness'].default_value
else:
link = principled_node.inputs['Roughness'].links[0]
node = link.from_node
if node.type == 'VALUE':
pack["roughnessFactor"] = node.outputs[0].default_value
else:
factor, image = try_extract_image_with_factor(link, 'G')
pack["roughnessFactor"] = factor
pack["roughnessTexture"] = image
output['materials'].append(pack)
        except Exception:
with open(arg.output_path + '_error.txt', 'w') as f:
f.write(str([[n.name] for n in mat.node_tree.nodes]))
raise RuntimeError("Material is not supported")
# Dumping meshes
for obj in scene.objects:
if obj.type != 'MESH':
continue
        pack = {
            "vertices": None,
            "faces": None,
            "normals": None,
            "uvs": None,
            "mat_ids": None,
        }
eval_obj = obj.evaluated_get(depsgraph)
eval_mesh = eval_obj.to_mesh()
bm = bmesh.new()
bm.from_mesh(eval_mesh)
bm.transform(obj.matrix_world)
bmesh.ops.triangulate(bm, faces=bm.faces)
bm.to_mesh(eval_mesh)
bm.free()
pack["vertices"] = np.array([
v.co[:] for v in eval_mesh.vertices
], dtype=np.float32) # (N, 3)
pack["faces"] = np.array([
[eval_mesh.loops[i].vertex_index for i in poly.loop_indices]
for poly in eval_mesh.polygons
], dtype=np.int32) # (F, 3)
pack["normals"] = np.array([
[eval_mesh.loops[i].normal for i in poly.loop_indices]
for poly in eval_mesh.polygons
], dtype=np.float32) # (F, 3, 3)
if eval_mesh.uv_layers.active is not None:
pack["uvs"] = np.array([
[eval_mesh.uv_layers.active.data[i].uv for i in poly.loop_indices]
for poly in eval_mesh.polygons
], dtype=np.float32) # (F, 3, 2)
pack["mat_ids"] = np.array([
bpy.data.materials.find(obj.material_slots[poly.material_index].name)
if len(obj.material_slots) > 0 and obj.material_slots[poly.material_index].material is not None else -1
for poly in eval_mesh.polygons
], dtype=np.int32)
output['objects'].append(pack)
# Save output
os.makedirs(os.path.dirname(arg.output_path), exist_ok=True)
with open(arg.output_path, 'wb') as f:
pickle.dump(output, f)
print('[INFO] Output saved to {}.'.format(arg.output_path))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Dumps triangulated geometry and PBR material data from a given 3D object file.')
    parser.add_argument('--object', type=str, help='Path to the 3D model file to be dumped.')
    parser.add_argument('--output_path', type=str, default='/tmp', help='The path the output will be dumped to.')
argv = sys.argv[sys.argv.index("--") + 1:]
args = parser.parse_args(argv)
main(args)
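The dump is a plain pickle with two lists: `materials` (glTF-style PBR factors and extracted textures) and `objects` (triangulated, world-space geometry). A minimal sketch of reading one back, assuming a dump produced by the script above (the path is illustrative):

```python
import pickle

# Illustrative path; point it at a dump produced by the script above.
with open('mesh_dumps/example.pickle', 'rb') as f:
    dump = pickle.load(f)

for i, mat in enumerate(dump['materials']):
    print(f"material {i}: baseColorFactor={mat['baseColorFactor']}, "
          f"alphaMode={mat['alphaMode']}, "
          f"textured={mat['baseColorTexture'] is not None}")

for i, obj in enumerate(dump['objects']):
    verts, faces = obj['vertices'], obj['faces']  # (N, 3) float32, (F, 3) int32
    print(f"object {i}: {len(verts)} vertices, {len(faces)} triangles, "
          f"uvs={'yes' if obj['uvs'] is not None else 'no'}")
```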

6
data_toolkit/blender_script/install_pillow.py
View File

@@ -0,0 +1,6 @@
import subprocess
import sys
import ensurepip
ensurepip.bootstrap()
subprocess.check_call([sys.executable, "-m", "pip", "install", "Pillow"])

Binary file not shown.

View File

@@ -0,0 +1,437 @@
import argparse, sys, os, math, re, glob
from typing import *
import bpy
from mathutils import Vector, Matrix
import numpy as np
import json
"""=============== BLENDER ==============="""
IMPORT_FUNCTIONS: Dict[str, Callable] = {
    # Blender 4.x operator names: the legacy import_scene.obj and
    # import_mesh.stl/ply operators were removed and would fail at call
    # time under the Blender 4.5 pinned by this toolkit.
    "obj": bpy.ops.wm.obj_import,
    "glb": bpy.ops.import_scene.gltf,
    "gltf": bpy.ops.import_scene.gltf,
    "usd": bpy.ops.wm.usd_import,
    "fbx": bpy.ops.import_scene.fbx,
    "stl": bpy.ops.wm.stl_import,
    "usda": bpy.ops.wm.usd_import,
    "dae": bpy.ops.wm.collada_import,
    "ply": bpy.ops.wm.ply_import,
    "abc": bpy.ops.wm.alembic_import,
    "blend": bpy.ops.wm.append,
}
EXT = {
'PNG': 'png',
'JPEG': 'jpg',
'OPEN_EXR': 'exr',
'TIFF': 'tiff',
'BMP': 'bmp',
'HDR': 'hdr',
'TARGA': 'tga'
}
def init_render(engine='CYCLES', resolution=512):
bpy.context.scene.render.engine = engine
bpy.context.scene.render.resolution_x = resolution
bpy.context.scene.render.resolution_y = resolution
bpy.context.scene.render.resolution_percentage = 100
bpy.context.scene.render.image_settings.file_format = 'PNG'
bpy.context.scene.render.image_settings.color_mode = 'RGBA'
bpy.context.scene.render.film_transparent = True
bpy.context.scene.cycles.device = 'GPU'
bpy.context.scene.cycles.samples = 32
bpy.context.scene.cycles.filter_type = 'BOX'
bpy.context.scene.cycles.filter_width = 1
bpy.context.scene.cycles.diffuse_bounces = 1
bpy.context.scene.cycles.glossy_bounces = 1
bpy.context.scene.cycles.transparent_max_bounces = 3
bpy.context.scene.cycles.transmission_bounces = 3
bpy.context.scene.cycles.use_denoising = True
bpy.context.preferences.addons['cycles'].preferences.get_devices()
bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
def init_scene() -> None:
"""Resets the scene to a clean state.
Returns:
None
"""
# delete everything
for obj in bpy.data.objects:
bpy.data.objects.remove(obj, do_unlink=True)
# delete all the materials
for material in bpy.data.materials:
bpy.data.materials.remove(material, do_unlink=True)
# delete all the textures
for texture in bpy.data.textures:
bpy.data.textures.remove(texture, do_unlink=True)
# delete all the images
for image in bpy.data.images:
bpy.data.images.remove(image, do_unlink=True)
def init_camera():
cam = bpy.data.objects.new('Camera', bpy.data.cameras.new('Camera'))
bpy.context.collection.objects.link(cam)
bpy.context.scene.camera = cam
cam.data.sensor_height = cam.data.sensor_width = 32
cam_constraint = cam.constraints.new(type='TRACK_TO')
cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'
cam_constraint.up_axis = 'UP_Y'
cam_empty = bpy.data.objects.new("Empty", None)
cam_empty.location = (0, 0, 0)
bpy.context.scene.collection.objects.link(cam_empty)
cam_constraint.target = cam_empty
return cam
def init_uniform_lighting():
# Clear existing lights
bpy.ops.object.select_all(action="DESELECT")
bpy.ops.object.select_by_type(type="LIGHT")
bpy.ops.object.delete()
# Create environment light
if bpy.context.scene.world is None:
world = bpy.data.worlds.new("World")
bpy.context.scene.world = world
else:
world = bpy.context.scene.world
# Enabling nodes
world.use_nodes = True
node_tree = world.node_tree
nodes = node_tree.nodes
links = node_tree.links
# Remove default nodes
for node in nodes:
nodes.remove(node)
# Create background node
bg_node = nodes.new(type="ShaderNodeBackground")
bg_node.inputs["Color"].default_value = (1.0, 1.0, 1.0, 1.0)
bg_node.inputs["Strength"].default_value = 1.0
output_node = nodes.new(type="ShaderNodeOutputWorld")
links.new(bg_node.outputs["Background"], output_node.inputs["Surface"])
def init_random_lighting(camera_dir: np.ndarray) -> None:
# Clear existing lights
bpy.ops.object.select_all(action="DESELECT")
bpy.ops.object.select_by_type(type="LIGHT")
bpy.ops.object.delete()
# Create environment light
if bpy.context.scene.world is None:
world = bpy.data.worlds.new("World")
bpy.context.scene.world = world
else:
world = bpy.context.scene.world
# Enabling nodes
world.use_nodes = True
node_tree = world.node_tree
nodes = node_tree.nodes
links = node_tree.links
# Remove default nodes
for node in nodes:
nodes.remove(node)
# Random place lights
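    # Sample 1-3 point lights on directions biased upward (z += 0.6). Strengths
    # are drawn so the camera-facing contribution stays within the 1.5 budget;
    # whatever remains goes to the uniform environment light created below.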
num_lights = np.random.randint(1, 4)
total_strength = 1.5
for i in range(num_lights):
new_light = bpy.data.objects.new(f"Light_{i}", bpy.data.lights.new(f"Light_{i}", type="POINT"))
bpy.context.collection.objects.link(new_light)
new_light_distance = 1 / np.random.uniform(1/100, 1/10)
new_light_dir = np.random.randn(3)
new_light_dir[2] += 0.6
new_light_dir = new_light_dir / np.linalg.norm(new_light_dir)
new_light_location = new_light_dir * new_light_distance
new_light_camera_strength_ratio = max(np.sum(camera_dir * new_light_dir) * 0.5 + 0.5, 0)
new_light_max_energy = total_strength / (np.sum(camera_dir * new_light_dir) * 0.45 + 0.55)
new_light_strength = np.sqrt(np.random.uniform(0.01, 1)) * new_light_max_energy
new_light_camera_strength = new_light_camera_strength_ratio * new_light_strength
total_strength -= new_light_camera_strength
new_light.location = (new_light_location[0], new_light_location[1], new_light_location[2])
new_light.data.color = (1.0, 1.0, 1.0)
new_light.data.energy = new_light_strength * new_light_distance**2 * 31.4
new_light.data.shadow_soft_size = np.random.uniform(0.1, 0.1 * new_light_distance)
# Create background node
bg_node = nodes.new(type="ShaderNodeBackground")
bg_node.inputs["Color"].default_value = (1.0, 1.0, 1.0, 1.0)
bg_node.inputs["Strength"].default_value = total_strength
output_node = nodes.new(type="ShaderNodeOutputWorld")
links.new(bg_node.outputs["Background"], output_node.inputs["Surface"])
def load_object(object_path: str) -> None:
"""Loads a model with a supported file extension into the scene.
Args:
object_path (str): Path to the model file.
Raises:
ValueError: If the file extension is not supported.
Returns:
None
"""
    file_extension = object_path.split(".")[-1].lower()
    if file_extension != "usdz" and file_extension not in IMPORT_FUNCTIONS:
        raise ValueError(f"Unsupported file type: {object_path}")
if file_extension == "usdz":
# install usdz io package
dirname = os.path.dirname(os.path.realpath(__file__))
usdz_package = os.path.join(dirname, "io_scene_usdz.zip")
bpy.ops.preferences.addon_install(filepath=usdz_package)
# enable it
addon_name = "io_scene_usdz"
bpy.ops.preferences.addon_enable(module=addon_name)
# import the usdz
from io_scene_usdz.import_usdz import import_usdz
        import_usdz(bpy.context, filepath=object_path, materials=True, animations=True)
return None
# load from existing import functions
import_function = IMPORT_FUNCTIONS[file_extension]
print(f"Loading object from {object_path}")
if file_extension == "blend":
import_function(directory=object_path, link=False)
elif file_extension in {"glb", "gltf"}:
import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS')
else:
import_function(filepath=object_path)
def delete_invisible_objects() -> None:
"""Deletes all invisible objects in the scene.
Returns:
None
"""
# bpy.ops.object.mode_set(mode="OBJECT")
bpy.ops.object.select_all(action="DESELECT")
for obj in bpy.context.scene.objects:
if obj.hide_viewport or obj.hide_render:
obj.hide_viewport = False
obj.hide_render = False
obj.hide_select = False
obj.select_set(True)
bpy.ops.object.delete()
# Delete invisible collections
invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
for col in invisible_collections:
bpy.data.collections.remove(col)
def unhide_all_objects() -> None:
"""Unhides all objects in the scene.
Returns:
None
"""
for obj in bpy.context.scene.objects:
obj.hide_set(False)
def convert_to_meshes() -> None:
"""Converts all objects in the scene to meshes.
Returns:
None
"""
bpy.ops.object.select_all(action="DESELECT")
bpy.context.view_layer.objects.active = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"][0]
for obj in bpy.context.scene.objects:
obj.select_set(True)
bpy.ops.object.convert(target="MESH")
def triangulate_meshes() -> None:
"""Triangulates all meshes in the scene.
Returns:
None
"""
bpy.ops.object.select_all(action="DESELECT")
objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
bpy.context.view_layer.objects.active = objs[0]
for obj in objs:
obj.select_set(True)
bpy.ops.object.mode_set(mode="EDIT")
bpy.ops.mesh.reveal()
bpy.ops.mesh.select_all(action="SELECT")
bpy.ops.mesh.quads_convert_to_tris(quad_method="BEAUTY", ngon_method="BEAUTY")
bpy.ops.object.mode_set(mode="OBJECT")
bpy.ops.object.select_all(action="DESELECT")
def scene_bbox() -> Tuple[Vector, Vector]:
"""Returns the bounding box of the scene.
Taken from Shap-E rendering script
(https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)
Returns:
Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
"""
bbox_min = (math.inf,) * 3
bbox_max = (-math.inf,) * 3
found = False
scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)]
for obj in scene_meshes:
found = True
for coord in obj.bound_box:
coord = Vector(coord)
coord = obj.matrix_world @ coord
bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
if not found:
raise RuntimeError("no objects in scene to compute bounding box for")
return Vector(bbox_min), Vector(bbox_max)
def normalize_scene() -> Tuple[float, Vector]:
"""Normalizes the scene by scaling and translating it to fit in a unit cube centered
at the origin.
Mostly taken from the Point-E / Shap-E rendering script
(https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
but fix for multiple root objects: (see bug report here:
https://github.com/openai/shap-e/pull/60).
Returns:
Tuple[float, Vector]: The scale factor and the offset applied to the scene.
"""
scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent]
if len(scene_root_objects) > 1:
# create an empty object to be used as a parent for all root objects
scene = bpy.data.objects.new("ParentEmpty", None)
bpy.context.scene.collection.objects.link(scene)
# parent all root objects to the empty object
for obj in scene_root_objects:
obj.parent = scene
else:
scene = scene_root_objects[0]
bbox_min, bbox_max = scene_bbox()
scale = 1 / max(bbox_max - bbox_min)
scene.scale = scene.scale * scale
# Apply scale to matrix_world.
bpy.context.view_layer.update()
bbox_min, bbox_max = scene_bbox()
offset = -(bbox_min + bbox_max) / 2
scene.matrix_world.translation += offset
bpy.ops.object.select_all(action="DESELECT")
return scale, offset
def get_transform_matrix(obj: bpy.types.Object) -> list:
pos, rt, _ = obj.matrix_world.decompose()
rt = rt.to_matrix()
matrix = []
for ii in range(3):
a = []
for jj in range(3):
a.append(rt[ii][jj])
a.append(pos[ii])
matrix.append(a)
matrix.append([0, 0, 0, 1])
return matrix
def main(arg):
if arg.object.endswith(".blend"):
delete_invisible_objects()
else:
init_scene()
load_object(arg.object)
print('[INFO] Scene initialized.')
# normalize scene
scale, offset = normalize_scene()
print('[INFO] Scene normalized.')
# Initialize camera and lighting
cam = init_camera()
init_uniform_lighting()
print('[INFO] Camera and lighting initialized.')
# ============= Render conditional views =============
init_render(engine=arg.engine, resolution=arg.cond_resolution)
# Create a list of views
to_export = {
"aabb": [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
"scale": scale,
"offset": [offset.x, offset.y, offset.z],
"frames": []
}
views = json.loads(arg.cond_views)
for i, view in enumerate(views):
cam_dir = np.array([
np.cos(view['yaw']) * np.cos(view['pitch']),
np.sin(view['yaw']) * np.cos(view['pitch']),
np.sin(view['pitch'])
])
init_random_lighting(cam_dir)
cam.location = (
view['radius'] * cam_dir[0],
view['radius'] * cam_dir[1],
view['radius'] * cam_dir[2]
)
cam.data.lens = 16 / np.tan(view['fov'] / 2)
bpy.context.scene.render.filepath = os.path.join(arg.cond_output_folder, f'{i:03d}.png')
# Render the scene
bpy.ops.render.render(write_still=True)
bpy.context.view_layer.update()
# Save camera parameters
metadata = {
"file_path": f'{i:03d}.png',
"camera_angle_x": view['fov'],
"transform_matrix": get_transform_matrix(cam)
}
to_export["frames"].append(metadata)
# Save the camera parameters
with open(os.path.join(arg.cond_output_folder, 'transforms.json'), 'w') as f:
json.dump(to_export, f, indent=4)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Renders a given object file by rotating a camera around it.')
    parser.add_argument('--object', type=str, help='Path to the 3D model file to be rendered.')
    parser.add_argument('--cond_views', type=str, help='JSON string of views: a list of {yaw, pitch, radius, fov} objects.')
parser.add_argument('--cond_output_folder', type=str, default='/tmp', help='The path the output will be dumped to.')
parser.add_argument('--cond_resolution', type=int, default=1024, help='Resolution of the conditional images.')
parser.add_argument('--engine', type=str, default='CYCLES', help='Blender internal engine for rendering. E.g. CYCLES, BLENDER_EEVEE, ...')
argv = sys.argv[sys.argv.index("--") + 1:]
args = parser.parse_args(argv)
main(args)
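The resulting `transforms.json` follows the NeRF-style convention: per-frame `camera_angle_x` (horizontal FoV) plus a 4x4 camera-to-world matrix. A minimal sketch of recovering pixel-space intrinsics and camera centers from it (paths are illustrative):

```python
import json
import numpy as np

with open('renders_cond/example/transforms.json') as f:  # illustrative path
    meta = json.load(f)

W = 1024  # must match the --cond_resolution used at render time
for frame in meta['frames']:
    fov = frame['camera_angle_x']               # horizontal FoV in radians
    focal = 0.5 * W / np.tan(0.5 * fov)         # pinhole focal length in pixels
    c2w = np.array(frame['transform_matrix'])   # 4x4 camera-to-world
    origin = c2w[:3, 3]                         # camera center in world space
    print(frame['file_path'], f'focal={focal:.1f}px', origin)
```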

219
data_toolkit/build_metadata.py Executable file
View File

@@ -0,0 +1,219 @@
import os
import shutil
import sys
import time
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
def update_metadata(path, opt):
if not os.path.exists(path):
return None
timestamp = str(int(time.time()))
os.makedirs(os.path.join(path, 'merged_records'), exist_ok=True)
os.makedirs(os.path.join(path, 'new_records'), exist_ok=True)
if opt.from_merged_records:
df_files = [f for f in os.listdir(os.path.join(path, 'merged_records')) if f.endswith('.csv')]
df_files = [f for f in df_files if int(f.split('_')[0]) >= opt.record_start]
else:
df_files = [f for f in os.listdir(os.path.join(path, 'new_records')) if f.startswith('part_') and f.endswith('.csv')]
df_parts = []
for f in df_files:
try:
df_parts.append(pd.read_csv(os.path.join(path, 'new_records', f)))
except Exception as e:
print(f"Failed to read {f}: {e}")
if len(df_parts) > 0:
if os.path.exists(os.path.join(path, 'metadata.csv')):
metadata = pd.read_csv(os.path.join(path, 'metadata.csv'))
else:
columns = df_parts[0].columns
metadata = pd.DataFrame(columns=columns)
metadata.set_index('sha256', inplace=True)
for df_part in df_parts:
if 'sha256' in df_part.columns:
df_part.set_index('sha256', inplace=True)
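            # combine_first keeps the new part's values and only fills its
            # missing cells from the accumulated metadata, so newer records win.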
metadata = df_part.combine_first(metadata)
metadata.to_csv(os.path.join(path, 'metadata.csv'))
for f in df_files:
shutil.move(os.path.join(path, 'new_records', f), os.path.join(path, 'merged_records', f'{timestamp}_{f}'))
return metadata
else:
if os.path.exists(os.path.join(path, 'metadata.csv')):
return pd.read_csv(os.path.join(path, 'metadata.csv'))
return None
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--download_root', type=str, default=None,
help='Directory to save the downloaded files')
parser.add_argument('--thumbnail_root', type=str, default=None,
help='Directory to save the thumbnail files')
parser.add_argument('--render_cond_root', type=str, default=None,
help='Directory to save the render condition files')
parser.add_argument('--mesh_dump_root', type=str, default=None,
help='Directory to save the mesh files')
parser.add_argument('--pbr_dump_root', type=str, default=None,
help='Directory to save the pbr files')
parser.add_argument('--dual_grid_root', type=str, default=None,
help='Directory to save the dual grid files')
parser.add_argument('--pbr_voxel_root', type=str, default=None,
help='Directory to save the pbr voxel files')
parser.add_argument('--ss_latent_root', type=str, default=None,
help='Directory to save the sparse structure latent files')
parser.add_argument('--shape_latent_root', type=str, default=None,
help='Directory to save the shape latent files')
parser.add_argument('--pbr_latent_root', type=str, default=None,
help='Directory to save the pbr latent files')
parser.add_argument('--field', type=str, default='all',
help='Fields to process, separated by commas')
    parser.add_argument('--from_file', action='store_true',
                        help='Build metadata from files instead of from processing records. ' +
                             'Useful when a processing step failed to generate records but the files already exist.')
parser.add_argument('--from_merged_records', action='store_true',
help='Build metadata from merged records')
    parser.add_argument('--record_start', type=int, default=0,
                        help='Earliest merged-record timestamp to include when using --from_merged_records')
parser.add_argument('--rebuild', action='store_true',
help='Rebuild metadata from scratch, ignore existing metadata.')
dataset_utils.add_args(parser)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
opt.download_root = opt.download_root or opt.root
opt.thumbnail_root = opt.thumbnail_root or opt.root
opt.render_cond_root = opt.render_cond_root or opt.root
opt.mesh_dump_root = opt.mesh_dump_root or opt.root
opt.pbr_dump_root = opt.pbr_dump_root or opt.root
opt.dual_grid_root = opt.dual_grid_root or opt.root
opt.pbr_voxel_root = opt.pbr_voxel_root or opt.root
opt.ss_latent_root = opt.ss_latent_root or opt.root
opt.shape_latent_root = opt.shape_latent_root or opt.root
opt.pbr_latent_root = opt.pbr_latent_root or opt.root
os.makedirs(opt.root, exist_ok=True)
opt.field = opt.field.split(',')
# get file list
if os.path.exists(os.path.join(opt.root, 'metadata.csv')):
print('Loading previous metadata...')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv'))
else:
metadata = dataset_utils.get_metadata(**opt)
metadata.to_csv(os.path.join(opt.root, 'metadata.csv'), index=False)
# merge downloaded
downloaded_metadata = update_metadata(os.path.join(opt.download_root, 'raw'), opt)
# merge thumbnails
thumbnail_metadata = update_metadata(os.path.join(opt.thumbnail_root, 'thumbnails'), opt)
# merge aesthetic scores
aesthetic_score_metadata = update_metadata(os.path.join(opt.root, 'aesthetic_scores'), opt)
# merge render conditions
render_cond_metadata = update_metadata(os.path.join(opt.render_cond_root, 'renders_cond'), opt)
# merge mesh dumped
mesh_dumped_metadata = update_metadata(os.path.join(opt.mesh_dump_root, 'mesh_dumps'), opt)
# merge pbr dumped
pbr_dumped_metadata = update_metadata(os.path.join(opt.pbr_dump_root, 'pbr_dumps'), opt)
# merge asset stats
asset_stats_metadata = update_metadata(os.path.join(opt.root, 'asset_stats'), opt)
# merge dual grid
dual_grid_resolutions = []
    for d in os.listdir(opt.dual_grid_root):
        if os.path.isdir(os.path.join(opt.dual_grid_root, d)) and d.startswith('dual_grid_'):
            dual_grid_resolutions.append(int(d.split('_')[-1]))
dual_grid_metadata = {}
for res in dual_grid_resolutions:
dual_grid_metadata[res] = update_metadata(os.path.join(opt.dual_grid_root, f'dual_grid_{res}'), opt)
# merge pbr voxelized
pbr_voxel_resolutions = []
    for d in os.listdir(opt.pbr_voxel_root):
        if os.path.isdir(os.path.join(opt.pbr_voxel_root, d)) and d.startswith('pbr_voxels_'):
            pbr_voxel_resolutions.append(int(d.split('_')[-1]))
pbr_voxel_metadata = {}
for res in pbr_voxel_resolutions:
pbr_voxel_metadata[res] = update_metadata(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{res}'), opt)
# merge ss latents
ss_latent_models = []
if os.path.exists(os.path.join(opt.ss_latent_root, 'ss_latents')):
ss_latent_models = os.listdir(os.path.join(opt.ss_latent_root, 'ss_latents'))
ss_latent_metadata = {}
for model in ss_latent_models:
ss_latent_metadata[model] = update_metadata(os.path.join(opt.ss_latent_root, f'ss_latents/{model}'), opt)
# merge shape latents
shape_latent_models = []
if os.path.exists(os.path.join(opt.shape_latent_root, 'shape_latents')):
shape_latent_models = os.listdir(os.path.join(opt.shape_latent_root, 'shape_latents'))
shape_latent_metadata = {}
for model in shape_latent_models:
shape_latent_metadata[model] = update_metadata(os.path.join(opt.shape_latent_root, f'shape_latents/{model}'), opt)
# merge pbr latents
pbr_latent_models = []
if os.path.exists(os.path.join(opt.pbr_latent_root, 'pbr_latents')):
pbr_latent_models = os.listdir(os.path.join(opt.pbr_latent_root, 'pbr_latents'))
pbr_latent_metadata = {}
for model in pbr_latent_models:
pbr_latent_metadata[model] = update_metadata(os.path.join(opt.pbr_latent_root, f'pbr_latents/{model}'), opt)
# statistics
num_downloaded = downloaded_metadata['local_path'].count() if downloaded_metadata is not None else 0
with open(os.path.join(opt.root, 'statistics.txt'), 'w') as f:
f.write('Statistics:\n')
f.write(f' - Number of assets: {len(metadata)}\n')
f.write(f' - Number of assets downloaded: {num_downloaded}\n')
if thumbnail_metadata is not None:
f.write(f' - Number of assets with thumbnails: {thumbnail_metadata["thumbnailed"].sum()}\n')
if aesthetic_score_metadata is not None:
f.write(f' - Number of assets with aesthetic scores: {aesthetic_score_metadata["aesthetic_score"].count()}\n')
if render_cond_metadata is not None:
f.write(f' - Number of assets with render conditions: {render_cond_metadata["cond_rendered"].count()}\n')
if mesh_dumped_metadata is not None:
f.write(f' - Number of assets with mesh dumped: {mesh_dumped_metadata["mesh_dumped"].sum()}\n')
if pbr_dumped_metadata is not None:
f.write(f' - Number of assets with PBR dumped: {pbr_dumped_metadata["pbr_dumped"].sum()}\n')
if asset_stats_metadata is not None:
f.write(f' - Number of assets with asset stats: {len(asset_stats_metadata)}\n')
if len(dual_grid_resolutions) != 0:
f.write(f' - Number of assets with dual grid:\n')
for res in dual_grid_resolutions:
if dual_grid_metadata[res] is not None:
f.write(f' - {res}: {dual_grid_metadata[res]["dual_grid_converted"].sum()}\n')
if len(pbr_voxel_resolutions) != 0:
f.write(f' - Number of assets with PBR voxelization:\n')
for res in pbr_voxel_resolutions:
if pbr_voxel_metadata[res] is not None:
f.write(f' - {res}: {pbr_voxel_metadata[res]["pbr_voxelized"].sum()}\n')
if len(ss_latent_models) != 0:
f.write(f' - Number of assets with sparse structure latents:\n')
for model in ss_latent_models:
if ss_latent_metadata[model] is not None:
f.write(f' - {model}: {ss_latent_metadata[model]["ss_latent_encoded"].sum()}\n')
if len(shape_latent_models) != 0:
f.write(f' - Number of assets with shape latents:\n')
for model in shape_latent_models:
if shape_latent_metadata[model] is not None:
f.write(f' - {model}: {shape_latent_metadata[model]["shape_latent_encoded"].sum()}\n')
if len(pbr_latent_models) != 0:
f.write(f' - Number of assets with PBR latents:\n')
for model in pbr_latent_models:
if pbr_latent_metadata[model] is not None:
f.write(f' - {model}: {pbr_latent_metadata[model]["pbr_latent_encoded"].sum()}\n')
with open(os.path.join(opt.root, 'statistics.txt'), 'r') as f:
print(f.read())
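Every processing stage feeds this merger by dropping per-rank CSVs into its `new_records/` folder; `update_metadata` then folds them into the stage's `metadata.csv` and archives them under `merged_records/` with a timestamp prefix. A minimal sketch of what a stage record looks like (paths and values are illustrative):

```python
import os
import pandas as pd

# One row per processed asset, keyed by sha256 (values are illustrative).
records = pd.DataFrame.from_records([
    {'sha256': '0123abcd' * 8, 'mesh_dumped': True},
])
stage_dir = 'datasets/Example/mesh_dumps'  # illustrative stage directory
os.makedirs(os.path.join(stage_dir, 'new_records'), exist_ok=True)
records.to_csv(os.path.join(stage_dir, 'new_records', 'part_0.csv'), index=False)
```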

64
data_toolkit/download.py Executable file
View File

@@ -0,0 +1,64 @@
import os
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--download_root', type=str, default=None,
help='Directory to download the objects')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--check_only', action='store_true',
help='Only check if the objects are already downloaded')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
dataset_utils.add_args(parser)
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
opt.download_root = opt.download_root or opt.root
os.makedirs(opt.root, exist_ok=True)
os.makedirs(opt.download_root, exist_ok=True)
os.makedirs(os.path.join(opt.download_root, 'raw', 'new_records'), exist_ok=True)
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.download_root, 'raw', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.download_root, 'raw', 'metadata.csv')).set_index('sha256'))
metadata = metadata.reset_index()
if opt.instances is None:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
if 'local_path' in metadata.columns:
metadata = metadata[metadata['local_path'].isna()]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
print(f'Processing {len(metadata)} objects...')
# process objects
downloaded = dataset_utils.download(metadata, **opt)
downloaded.to_csv(os.path.join(opt.download_root, 'raw', 'new_records', f'part_{opt.rank}.csv'), index=False)
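All toolkit scripts shard work the same way: each rank takes a contiguous slice of the metadata, computed with integer arithmetic so shard sizes differ by at most one. A quick sketch of the split:

```python
def shard_bounds(n: int, rank: int, world_size: int) -> tuple:
    # Same contiguous split used by the toolkit scripts above.
    return n * rank // world_size, n * (rank + 1) // world_size

# e.g. 10 items over 3 workers -> (0, 3), (3, 6), (6, 10)
print([shard_bounds(10, r, 3) for r in range(3)])
```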

170
data_toolkit/dual_grid.py Executable file
View File

@@ -0,0 +1,170 @@
import os
import sys
import importlib
import argparse
import pandas as pd
import numpy as np
import torch
import pickle
import o_voxel
from easydict import EasyDict as edict
from functools import partial
def _dual_grid_mesh(file, metadatum, mesh_dump_root, root):
sha256 = metadatum['sha256']
try:
pack = {'sha256': sha256}
data = None
for res in opt.resolution:
need_process = False
# check if already processed
if os.path.exists(os.path.join(root, f'dual_grid_{res}', f'{sha256}.vxz')):
try:
info = o_voxel.io.read_vxz_info(os.path.join(root, f'dual_grid_{res}', f'{sha256}.vxz'))
pack[f'dual_grid_converted_{res}'] = True
pack[f'dual_grid_size_{res}'] = info['num_voxel']
except Exception as e:
print(f'Error reading {sha256}.vxz: {e}')
need_process = True
else:
need_process = True
# process mesh
if need_process:
if data is None:
with open(os.path.join(mesh_dump_root, 'mesh_dumps', f'{sha256}.pickle'), 'rb') as f:
dump = pickle.load(f)
start = 0
vertices = []
faces = []
for obj in dump['objects']:
if obj['vertices'].size == 0 or obj['faces'].size == 0:
continue
vertices.append(obj['vertices'])
faces.append(obj['faces'] + start)
start += len(obj['vertices'])
vertices = torch.from_numpy(np.concatenate(vertices, axis=0)).float()
faces = torch.from_numpy(np.concatenate(faces, axis=0)).long()
vertices_min = vertices.min(dim=0)[0]
vertices_max = vertices.max(dim=0)[0]
center = (vertices_min + vertices_max) / 2
scale = 0.99999 / (vertices_max - vertices_min).max()
vertices = (vertices - center) * scale
assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range'
data = {'vertices': vertices, 'faces': faces}
voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
**data,
grid_size=res,
aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]],
face_weight=1.0,
boundary_weight=0.2,
regularization_weight=1e-2,
timing=False,
)
dual_vertices = dual_vertices * res - voxel_indices
assert torch.all(dual_vertices >= -1e-3) and torch.all(dual_vertices <= 1+1e-3), 'dual_vertices out of range'
dual_vertices = torch.clamp(dual_vertices, 0, 1)
dual_vertices = (dual_vertices * 255).type(torch.uint8)
intersected = (intersected[:, 0:1] + 2 * intersected[:, 1:2] + 4 * intersected[:, 2:3]).type(torch.uint8)
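                # The payload written below stores per-voxel dual-vertex offsets
                # quantized to uint8 in [0, 1], plus the three axis-intersection
                # flags packed into a single uint8 (bit0 = x, bit1 = y, bit2 = z).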
o_voxel.io.write_vxz(
os.path.join(root, f'dual_grid_{res}', f'{sha256}.vxz'),
voxel_indices,
{'vertices': dual_vertices, 'intersected': intersected},
)
pack[f'dual_grid_converted_{res}'] = True
pack[f'dual_grid_size_{res}'] = len(dual_vertices)
return pack
except Exception as e:
print(f'Error processing {sha256}: {e}')
return {'sha256': sha256, 'error': str(e)}
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--mesh_dump_root', type=str, default=None,
help='Directory to load mesh dumps')
parser.add_argument('--dual_grid_root', type=str, default=None,
help='Directory to save dual grids')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
dataset_utils.add_args(parser)
parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--resolution', type=str, default='256')
parser.add_argument('--world_size', type=int, default=1)
parser.add_argument('--max_workers', type=int, default=0)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
opt.resolution = [int(x) for x in opt.resolution.split(',')]
opt.mesh_dump_root = opt.mesh_dump_root or opt.root
opt.dual_grid_root = opt.dual_grid_root or opt.root
for res in opt.resolution:
os.makedirs(os.path.join(opt.dual_grid_root, f'dual_grid_{res}', 'new_records'), exist_ok=True)
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'metadata.csv')).set_index('sha256'))
for res in opt.resolution:
if os.path.exists(os.path.join(opt.dual_grid_root, f'dual_grid_{res}', 'metadata.csv')):
dual_grid_metadata = pd.read_csv(os.path.join(opt.dual_grid_root, f'dual_grid_{res}', 'metadata.csv')).set_index('sha256')
dual_grid_metadata = dual_grid_metadata.rename(columns={'dual_grid_converted': f'dual_grid_converted_{res}', 'dual_grid_size': f'dual_grid_size_{res}'})
metadata = metadata.combine_first(dual_grid_metadata)
metadata = metadata.reset_index()
if opt.instances is None:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
metadata = metadata[metadata['mesh_dumped'] == True]
mask = np.zeros(len(metadata), dtype=bool)
for res in opt.resolution:
if f'dual_grid_converted_{res}' in metadata.columns:
mask |= metadata[f'dual_grid_converted_{res}'] != True
else:
mask[:] = True
break
metadata = metadata[mask]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
print(f'Processing {len(metadata)} objects...')
# process objects
func = partial(_dual_grid_mesh, root=opt.dual_grid_root, mesh_dump_root=opt.mesh_dump_root)
dual_grids = dataset_utils.foreach_instance(metadata, None, func, max_workers=opt.max_workers, no_file=True, desc='Dual griding')
if 'error' in dual_grids.columns:
        errors = dual_grids[dual_grids['error'].notna()]
with open('errors.txt', 'w') as f:
f.write('\n'.join(errors['sha256'].tolist()))
for res in opt.resolution:
if f'dual_grid_converted_{res}' in dual_grids.columns:
dual_grid_metadata = dual_grids[dual_grids[f'dual_grid_converted_{res}'] == True]
if len(dual_grid_metadata) > 0:
dual_grid_metadata = dual_grid_metadata[['sha256', f'dual_grid_converted_{res}', f'dual_grid_size_{res}']]
dual_grid_metadata = dual_grid_metadata.rename(columns={f'dual_grid_converted_{res}': 'dual_grid_converted', f'dual_grid_size_{res}': 'dual_grid_size'})
dual_grid_metadata.to_csv(os.path.join(opt.dual_grid_root, f'dual_grid_{res}', 'new_records', f'part_{opt.rank}.csv'), index=False)
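To recover world-space dual vertices from a stored `.vxz`, invert the quantization above: the uint8 attributes are per-voxel offsets in [0, 1], so adding them back to the voxel indices and dividing by the resolution maps into the [-0.5, 0.5] cube. A minimal sketch under that assumption (path and resolution are illustrative):

```python
import o_voxel

res = 256  # illustrative; must match the dual_grid_{res} folder
coords, attr = o_voxel.io.read_vxz('dual_grid_256/example.vxz', num_threads=4)
# Undo the uint8 quantization, then map grid coordinates to the unit cube.
dual_vertices = (coords.float() + attr['vertices'].float() / 255.0) / res - 0.5
intersected = attr['intersected']  # per-voxel bit-packed axis flags (see above)
```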

127
data_toolkit/dump_mesh.py Executable file
View File

@@ -0,0 +1,127 @@
import os
import shutil
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
from subprocess import DEVNULL, call
import numpy as np
import tempfile
BLENDER_LINK = 'https://ftp.halifax.rwth-aachen.de/blender/release/Blender4.5/blender-4.5.1-linux-x64.tar.xz'
BLENDER_INSTALLATION_PATH = '/tmp'
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-4.5.1-linux-x64/blender'
def _install_blender():
if not os.path.exists(BLENDER_PATH):
os.system('sudo apt-get update')
os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6 libxfixes3 libgl1')
os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-4.5.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')
def _dump_mesh(file_path, metadatum, root):
sha256 = metadatum['sha256']
with tempfile.TemporaryDirectory() as tmp_dir:
temp_path = os.path.join(tmp_dir, f'{sha256}.pickle')
output_path = os.path.join(root, 'mesh_dumps', f'{sha256}.pickle')
args = [
BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'dump_mesh.py'),
'--',
'--object', os.path.expanduser(file_path),
'--output_path', os.path.expanduser(temp_path)
]
if file_path.endswith('.blend'):
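            # Pass the .blend file as Blender's positional argument so its own
            # scene contents are loaded before the dump script runs.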
args.insert(1, file_path)
call(args, stdout=DEVNULL, stderr=DEVNULL)
if os.path.exists(temp_path):
shutil.move(temp_path, output_path)
return {'sha256': sha256, 'mesh_dumped': True}
else:
if os.path.exists(temp_path + '_error.txt'):
with open(temp_path + '_error.txt', 'r') as f:
error_msg = f.read()
raise ValueError(f'Failed to dump mesh. File {file_path}. Error message: {error_msg}')
else:
raise ValueError(f'Failed to dump mesh. File {file_path}.')
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--download_root', type=str, default=None,
help='Directory to save the downloaded files')
parser.add_argument('--mesh_dump_root', type=str, default=None,
help='Directory to save the mesh dumps')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
dataset_utils.add_args(parser)
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
parser.add_argument('--max_workers', type=int, default=0)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
opt.download_root = opt.download_root or opt.root
opt.mesh_dump_root = opt.mesh_dump_root or opt.root
os.makedirs(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'new_records'), exist_ok=True)
# install blender
print('Checking blender...', flush=True)
_install_blender()
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.download_root, 'raw', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.download_root, 'raw', 'metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'metadata.csv')).set_index('sha256'))
metadata = metadata.reset_index()
if opt.instances is None:
metadata = metadata[metadata['local_path'].notna()]
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
if 'mesh_dumped' in metadata.columns:
metadata = metadata[metadata['mesh_dumped'] != True]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
records = []
# filter out objects that are already processed
sha256_list = os.listdir(os.path.join(opt.mesh_dump_root, 'mesh_dumps'))
sha256_list = [os.path.splitext(f)[0] for f in sha256_list if f.endswith('.pickle')]
for sha256 in sha256_list:
records.append({'sha256': sha256, 'mesh_dumped': True})
print(f'Found {len(sha256_list)} dumped mesh')
metadata = metadata[~metadata['sha256'].isin(sha256_list)]
print(f'Processing {len(metadata)} objects...')
# process objects
func = partial(_dump_mesh, root=opt.mesh_dump_root)
mesh_dumped = dataset_utils.foreach_instance(metadata, opt.download_root, func, max_workers=opt.max_workers, desc='Dumping mesh')
mesh_dumped = pd.concat([mesh_dumped, pd.DataFrame.from_records(records)])
mesh_dumped.to_csv(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'new_records', f'part_{opt.rank}.csv'), index=False)

128
data_toolkit/dump_pbr.py Executable file
View File

@@ -0,0 +1,128 @@
import os
import shutil
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
from subprocess import DEVNULL, call
import numpy as np
import tempfile
BLENDER_LINK = 'https://ftp.halifax.rwth-aachen.de/blender/release/Blender4.5/blender-4.5.1-linux-x64.tar.xz'
BLENDER_INSTALLATION_PATH = '/tmp'
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-4.5.1-linux-x64/blender'
def _install_blender():
if not os.path.exists(BLENDER_PATH):
os.system('sudo apt-get update')
os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6 libxfixes3 libgl1')
os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-4.5.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')
os.system(f'{BLENDER_PATH} -b --python {os.path.join(os.path.dirname(__file__), "blender_script", "install_pillow.py")}')
def _dump_pbr(file_path, metadatum, root):
sha256 = metadatum['sha256']
with tempfile.TemporaryDirectory() as tmp_dir:
temp_path = os.path.join(tmp_dir, f'{sha256}.pickle')
output_path = os.path.join(root, 'pbr_dumps', f'{sha256}.pickle')
args = [
BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'dump_pbr.py'),
'--',
'--object', os.path.expanduser(file_path),
'--output_path', os.path.expanduser(temp_path)
]
if file_path.endswith('.blend'):
args.insert(1, file_path)
call(args, stdout=DEVNULL, stderr=DEVNULL)
if os.path.exists(temp_path):
shutil.move(temp_path, output_path)
return {'sha256': sha256, 'pbr_dumped': True}
else:
if os.path.exists(temp_path + '_error.txt'):
with open(temp_path + '_error.txt', 'r') as f:
error_msg = f.read()
raise ValueError(f'Failed to dump PBR. File {file_path}. Error message: {error_msg}')
else:
raise ValueError(f'Failed to dump PBR. File {file_path}.')
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--download_root', type=str, default=None,
help='Directory to save the downloaded files')
    parser.add_argument('--pbr_dump_root', type=str, default=None,
                        help='Directory to save the PBR dumps')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
dataset_utils.add_args(parser)
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
parser.add_argument('--max_workers', type=int, default=0)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
opt.download_root = opt.download_root or opt.root
opt.pbr_dump_root = opt.pbr_dump_root or opt.root
os.makedirs(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'new_records'), exist_ok=True)
# install blender
print('Checking blender...', flush=True)
_install_blender()
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.download_root, 'raw', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.download_root, 'raw', 'metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')).set_index('sha256'))
metadata = metadata.reset_index()
if opt.instances is None:
metadata = metadata[metadata['local_path'].notna()]
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
if 'pbr_dumped' in metadata.columns:
metadata = metadata[metadata['pbr_dumped'] != True]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
records = []
# filter out objects that are already processed
sha256_list = os.listdir(os.path.join(opt.pbr_dump_root, 'pbr_dumps'))
sha256_list = [os.path.splitext(f)[0] for f in sha256_list if f.endswith('.pickle')]
for sha256 in sha256_list:
records.append({'sha256': sha256, 'pbr_dumped': True})
print(f'Found {len(sha256_list)} dumped PBRs')
metadata = metadata[~metadata['sha256'].isin(sha256_list)]
print(f'Processing {len(metadata)} objects...')
# process objects
func = partial(_dump_pbr, root=opt.pbr_dump_root)
pbr_dumped = dataset_utils.foreach_instance(metadata, opt.download_root, func, max_workers=opt.max_workers, desc='Dumping PBR')
pbr_dumped = pd.concat([pbr_dumped, pd.DataFrame.from_records(records)])
pbr_dumped.to_csv(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'new_records', f'part_{opt.rank}.csv'), index=False)

View File

@@ -0,0 +1,181 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import json
import argparse
import torch
import numpy as np
import pandas as pd
import o_voxel
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import trellis2.models as models
import trellis2.modules.sparse as sp
torch.set_grad_enabled(False)
def is_valid_sparse_tensor(tensor):
return torch.isfinite(tensor.feats).all() and torch.isfinite(tensor.coords).all()
def clear_cuda_error():
torch.cuda.synchronize()
torch.cuda.empty_cache()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--pbr_voxel_root', type=str, default=None,
help='Directory to save the pbr voxel files')
parser.add_argument('--pbr_latent_root', type=str, default=None,
help='Directory to save the pbr latent files')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--resolution', type=int, default=1024,
help='Sparse voxel resolution')
parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS.2-4B/ckpts/tex_enc_next_dc_f16c32_fp16',
help='Pretrained encoder model')
parser.add_argument('--model_root', type=str,
help='Root directory of models')
    parser.add_argument('--enc_model', type=str,
                        help='Encoder model. If specified, use this model instead of the pretrained model')
parser.add_argument('--ckpt', type=str,
help='Checkpoint to load')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
opt = parser.parse_args()
opt = edict(vars(opt))
opt.pbr_voxel_root = opt.pbr_voxel_root or opt.root
opt.pbr_latent_root = opt.pbr_latent_root or opt.root
if opt.enc_model is None:
latent_name = f'{opt.enc_pretrained.split("/")[-1]}_{opt.resolution}'
encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
else:
latent_name = f'{opt.enc_model.split("/")[-1]}_{opt.ckpt}_{opt.resolution}'
cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
encoder.load_state_dict(torch.load(ckpt_path), strict=False)
encoder.eval()
print(f'Loaded model from {ckpt_path}')
os.makedirs(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, 'new_records'), exist_ok=True)
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{opt.resolution}', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{opt.resolution}','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name,'metadata.csv')).set_index('sha256'))
metadata = metadata.reset_index()
if opt.instances is None:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
metadata = metadata[metadata['pbr_voxelized'] == True]
if 'pbr_latent_encoded' in metadata.columns:
metadata = metadata[metadata['pbr_latent_encoded'] != True]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
records = []
# filter out objects that are already processed
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
tqdm(total=len(metadata), desc="Filtering existing objects") as pbar:
def check_sha256(sha256):
if os.path.exists(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, f'{sha256}.npz')):
coords = np.load(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, f'{sha256}.npz'))['coords']
records.append({'sha256': sha256, 'pbr_latent_encoded': True, 'pbr_latent_tokens': coords.shape[0]})
pbar.update()
executor.map(check_sha256, metadata['sha256'].values)
executor.shutdown(wait=True)
existing_sha256 = set(r['sha256'] for r in records)
print(f'Found {len(existing_sha256)} processed objects')
metadata = metadata[~metadata['sha256'].isin(existing_sha256)]
print(f'Processing {len(metadata)} objects...')
sha256s = list(metadata['sha256'].values)
load_queue = Queue(maxsize=32)
with ThreadPoolExecutor(max_workers=32) as loader_executor, \
ThreadPoolExecutor(max_workers=32) as saver_executor:
def loader(sha256):
try:
attrs = ['base_color', 'metallic', 'roughness', 'alpha']
coords, attr = o_voxel.io.read_vxz(
os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{opt.resolution}', f'{sha256}.vxz'),
num_threads=4
)
feats = torch.concat([attr[k] for k in attrs], dim=-1) / 255.0 * 2 - 1
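                # SparseTensor coords are [batch, x, y, z]; prepend a zero
                # batch column since each asset is encoded individually.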
x = sp.SparseTensor(
feats.float(),
torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1),
)
load_queue.put((sha256, x))
except Exception as e:
print(f"[Loader Error] {sha256}: {e}")
load_queue.put((sha256, None))
loader_executor.map(loader, sha256s)
def saver(sha256, pack):
save_path = os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, f'{sha256}.npz')
np.savez_compressed(save_path, **pack)
records.append({'sha256': sha256, 'pbr_latent_encoded': True, 'pbr_latent_tokens': pack['coords'].shape[0]})
for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
try:
sha256, voxels = load_queue.get()
if voxels is None:
print(f"[Skip] {sha256}: Failed to load input")
continue
num_voxels = voxels.feats.shape[0]
# NaN/Inf
if not (is_valid_sparse_tensor(voxels)):
print(f"[Skip] {sha256}: NaN/Inf in input")
continue
z = encoder(voxels.cuda())
torch.cuda.synchronize()
if not torch.isfinite(z.feats).all():
print(f"[Skip] {sha256}: Non-finite latent in z.feats")
clear_cuda_error()
continue
pack = {
'feats': z.feats.cpu().numpy().astype(np.float32),
'coords': z.coords[:, 1:].cpu().numpy().astype(np.uint8),
}
saver_executor.submit(saver, sha256, pack)
except Exception as e:
print(f"[Error] {sha256} ({num_voxels} voxels): {e}")
clear_cuda_error()
continue
saver_executor.shutdown(wait=True)
records = pd.DataFrame.from_records(records)
records.to_csv(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, 'new_records', f'part_{opt.rank}.csv'), index=False)
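The saved `.npz` latents can be re-assembled into sparse tensors, e.g. when feeding them to a flow-model trainer. A minimal sketch (the path is illustrative):

```python
import numpy as np
import torch
import trellis2.modules.sparse as sp

pack = np.load('pbr_latents/example.npz')        # illustrative path
feats = torch.from_numpy(pack['feats'])          # (N, C) float32
coords = torch.from_numpy(pack['coords']).int()  # (N, 3), stored as uint8
z = sp.SparseTensor(
    feats,
    torch.cat([torch.zeros_like(coords[:, :1]), coords], dim=-1),  # batch col
)
```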

View File

@@ -0,0 +1,184 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import json
import argparse
import torch
import numpy as np
import pandas as pd
import o_voxel
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import trellis2.models as models
import trellis2.modules.sparse as sp
torch.set_grad_enabled(False)
def is_valid_sparse_tensor(tensor):
return torch.isfinite(tensor.feats).all() and torch.isfinite(tensor.coords).all()
def clear_cuda_error():
torch.cuda.synchronize()
torch.cuda.empty_cache()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--dual_grid_root', type=str, default=None,
help='Directory to save the dual grids')
parser.add_argument('--shape_latent_root', type=str, default=None,
help='Directory to save the shape latent files')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--resolution', type=int, default=1024,
help='Sparse voxel resolution')
parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS.2-4B/ckpts/shape_enc_next_dc_f16c32_fp16',
help='Pretrained encoder model')
parser.add_argument('--model_root', type=str,
help='Root directory of models')
    parser.add_argument('--enc_model', type=str,
                        help='Encoder model. If specified, use this model instead of the pretrained model')
parser.add_argument('--ckpt', type=str,
help='Checkpoint to load')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
opt = parser.parse_args()
opt = edict(vars(opt))
opt.dual_grid_root = opt.dual_grid_root or opt.root
opt.shape_latent_root = opt.shape_latent_root or opt.root
if opt.enc_model is None:
latent_name = f'{opt.enc_pretrained.split("/")[-1]}_{opt.resolution}'
encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
else:
latent_name = f'{opt.enc_model.split("/")[-1]}_{opt.ckpt}_{opt.resolution}'
cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
encoder.load_state_dict(torch.load(ckpt_path), strict=False)
encoder.eval()
print(f'Loaded model from {ckpt_path}')
os.makedirs(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, 'new_records'), exist_ok=True)
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.dual_grid_root, f'dual_grid_{opt.resolution}', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.dual_grid_root, f'dual_grid_{opt.resolution}','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name,'metadata.csv')).set_index('sha256'))
metadata = metadata.reset_index()
if opt.instances is None:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
metadata = metadata[metadata['dual_grid_converted'] == True]
if 'shape_latent_encoded' in metadata.columns:
metadata = metadata[metadata['shape_latent_encoded'] != True]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
records = []
# filter out objects that are already processed
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
tqdm(total=len(metadata), desc="Filtering existing objects") as pbar:
def check_sha256(sha256):
if os.path.exists(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, f'{sha256}.npz')):
coords = np.load(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, f'{sha256}.npz'))['coords']
records.append({'sha256': sha256, 'shape_latent_encoded': True, 'shape_latent_tokens': coords.shape[0]})
pbar.update()
executor.map(check_sha256, metadata['sha256'].values)
executor.shutdown(wait=True)
existing_sha256 = set(r['sha256'] for r in records)
print(f'Found {len(existing_sha256)} processed objects')
metadata = metadata[~metadata['sha256'].isin(existing_sha256)]
print(f'Processing {len(metadata)} objects...')
sha256s = list(metadata['sha256'].values)
load_queue = Queue(maxsize=32)
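# Producer-consumer pipeline: loader threads read .vxz files from disk into a
# bounded queue, the main thread runs the encoder on the GPU, and saver threads
# write latents asynchronously, so disk I/O overlaps with compute.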
with ThreadPoolExecutor(max_workers=32) as loader_executor, \
ThreadPoolExecutor(max_workers=32) as saver_executor:
def loader(sha256):
try:
coords, attr = o_voxel.io.read_vxz(
os.path.join(opt.dual_grid_root, f'dual_grid_{opt.resolution}', f'{sha256}.vxz'),
num_threads=4
)
vertices = sp.SparseTensor(
(attr['vertices'] / 255.0).float(),
torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1),
)
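# 'intersected' packs three per-axis edge-intersection flags into one integer;
# unpack its low three bits into boolean channels (reading of the format
# inferred from the bit arithmetic below).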
intersected = vertices.replace(torch.cat([
attr['intersected'] % 2,
attr['intersected'] // 2 % 2,
attr['intersected'] // 4 % 2,
], dim=-1).bool())
load_queue.put((sha256, vertices, intersected))
except Exception as e:
print(f"[Loader Error] {sha256}: {e}")
load_queue.put((sha256, None, None))
loader_executor.map(loader, sha256s)
def saver(sha256, pack):
save_path = os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, f'{sha256}.npz')
np.savez_compressed(save_path, **pack)
records.append({'sha256': sha256, 'shape_latent_encoded': True, 'shape_latent_tokens': pack['coords'].shape[0]})
for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
sha256, num_voxels = 'unknown', 0  # pre-bind so the except clause below cannot raise NameError
try:
sha256, vertices, intersected = load_queue.get()
if vertices is None or intersected is None:
print(f"[Skip] {sha256}: Failed to load input")
continue
num_voxels = vertices.feats.shape[0]
# skip inputs containing NaN/Inf before moving them to the GPU
if not (is_valid_sparse_tensor(vertices) and is_valid_sparse_tensor(intersected)):
print(f"[Skip] {sha256}: NaN/Inf in input")
continue
z = encoder(vertices.cuda(), intersected.cuda())
torch.cuda.synchronize()
if not torch.isfinite(z.feats).all():
print(f"[Skip] {sha256}: Non-finite latent in z.feats")
clear_cuda_error()
continue
pack = {
'feats': z.feats.cpu().numpy().astype(np.float32),
'coords': z.coords[:, 1:].cpu().numpy().astype(np.uint8),
}
saver_executor.submit(saver, sha256, pack)
except Exception as e:
print(f"[Error] {sha256} ({num_voxels} voxels): {e}")
clear_cuda_error()
continue
saver_executor.shutdown(wait=True)
records = pd.DataFrame.from_records(records)
records.to_csv(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, 'new_records', f'part_{opt.rank}.csv'), index=False)
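For reference, a minimal single-process invocation of this shape-latent encoding script might look as follows; the script path `data_toolkit/encode_shape_latent.py` is a hypothetical name (the filename is not shown in this view), and the dataset path follows the layout used elsewhere in this release:

```sh
# hypothetical script name and placeholder paths; dual_grid_root and
# shape_latent_root default to --root when omitted
python data_toolkit/encode_shape_latent.py \
    --root datasets/ObjaverseXL_sketchfab \
    --resolution 1024 \
    --rank 0 --world_size 1
```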

View File

@@ -0,0 +1,163 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import json
import argparse
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import trellis2.models as models
torch.set_grad_enabled(False)
def is_valid_sparse_tensor(tensor):
return torch.isfinite(tensor.feats).all() and torch.isfinite(tensor.coords).all()
def clear_cuda_error():
torch.cuda.synchronize()
torch.cuda.empty_cache()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--shape_latent_root', type=str, default=None,
help='Directory containing the shape latent files')
parser.add_argument('--ss_latent_root', type=str, default=None,
help='Directory to save the sparse structure latent files')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--resolution', type=int, default=64,
help='Sparse voxel resolution')
parser.add_argument('--shape_latent_name', type=str, default=None,
help='Name of the shape latent folder produced by the shape-latent encoding step')
parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS-image-large/ckpts/ss_enc_conv3d_16l8_fp16',
help='Pretrained encoder model')
parser.add_argument('--model_root', type=str,
help='Root directory of models')
parser.add_argument('--enc_model', type=str,
help='Encoder model. If specified, use this model instead of the pretrained model')
parser.add_argument('--ckpt', type=str,
help='Checkpoint to load')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process: a file with one sha256 per line, or a comma-separated list')
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
opt = parser.parse_args()
opt = edict(vars(opt))
opt.shape_latent_root = opt.shape_latent_root or opt.root
opt.ss_latent_root = opt.ss_latent_root or opt.root
if opt.enc_model is None:
latent_name = f'{opt.enc_pretrained.split("/")[-1]}_{opt.resolution}'
encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
else:
latent_name = f'{opt.enc_model.split("/")[-1]}_{opt.ckpt}_{opt.resolution}'
cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
encoder.load_state_dict(torch.load(ckpt_path), strict=False)
encoder.eval()
print(f'Loaded model from {ckpt_path}')
os.makedirs(os.path.join(opt.ss_latent_root, 'ss_latents', latent_name, 'new_records'), exist_ok=True)
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.shape_latent_root, 'shape_latents', opt.shape_latent_name, 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.shape_latent_root, 'shape_latents', opt.shape_latent_name,'metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.ss_latent_root,'ss_latents', latent_name, 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.ss_latent_root,'ss_latents', latent_name,'metadata.csv')).set_index('sha256'))
metadata = metadata.reset_index()
if opt.instances is None:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
metadata = metadata[metadata['shape_latent_encoded'] == True]
if 'ss_latent_encoded' in metadata.columns:
metadata = metadata[metadata['ss_latent_encoded'] != True]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
records = []
# filter out objects that are already processed
sha256_list = os.listdir(os.path.join(opt.ss_latent_root, 'ss_latents', latent_name))
sha256_list = [os.path.splitext(f)[0] for f in sha256_list if f.endswith('.npz')]
for sha256 in sha256_list:
records.append({'sha256': sha256, 'ss_latent_encoded': True})
print(f'Found {len(sha256_list)} processed objects')
metadata = metadata[~metadata['sha256'].isin(sha256_list)]
print(f'Processing {len(metadata)} objects...')
sha256s = list(metadata['sha256'].values)
load_queue = Queue(maxsize=32)
with ThreadPoolExecutor(max_workers=32) as loader_executor, \
ThreadPoolExecutor(max_workers=32) as saver_executor:
def loader(sha256):
try:
coords = np.load(os.path.join(opt.shape_latent_root, 'shape_latents', opt.shape_latent_name, f'{sha256}.npz'))['coords']
assert np.all(coords < opt.resolution), f"{sha256}: Invalid coords"
coords = torch.from_numpy(coords).long()
ss = torch.zeros(1, opt.resolution, opt.resolution, opt.resolution, dtype=torch.long)
ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
load_queue.put((sha256, ss))
except Exception as e:
print(f"[Loader Error] {sha256}: {e}")
load_queue.put((sha256, None))
loader_executor.map(loader, sha256s)
def saver(sha256, pack):
save_path = os.path.join(opt.ss_latent_root, 'ss_latents', latent_name, f'{sha256}.npz')
np.savez_compressed(save_path, **pack)
records.append({'sha256': sha256, 'ss_latent_encoded': True})
for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
try:
sha256, ss = load_queue.get()
if ss is None:
print(f"[Skip] {sha256}: Failed to load input")
continue
ss = ss.cuda()[None].float()
z = encoder(ss, sample_posterior=False)
torch.cuda.synchronize()
if not torch.isfinite(z).all():
print(f"[Skip] {sha256}: Non-finite latent")
clear_cuda_error()
continue
pack = {
'z': z[0].cpu().numpy(),
}
saver_executor.submit(saver, sha256, pack)
except Exception as e:
print(f"[Error] {sha256}: {e}")
clear_cuda_error()
continue
saver_executor.shutdown(wait=True)
records = pd.DataFrame.from_records(records)
records.to_csv(os.path.join(opt.ss_latent_root, 'ss_latents', latent_name, 'new_records', f'part_{opt.rank}.csv'), index=False)
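This sparse-structure script shards work by `--rank`/`--world_size`, so it can be fanned out across GPUs. A sketch, assuming a hypothetical script name and that shape latents were produced with the default pretrained encoder at resolution 1024 (so the latent folder is `shape_enc_next_dc_f16c32_fp16_1024`):

```sh
# hypothetical script name; paths are placeholders
for i in 0 1 2 3; do
    CUDA_VISIBLE_DEVICES=$i python data_toolkit/encode_ss_latent.py \
        --root datasets/ObjaverseXL_sketchfab \
        --shape_latent_name shape_enc_next_dc_f16c32_fp16_1024 \
        --resolution 64 \
        --rank $i --world_size 4 &
done
wait
```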

146
data_toolkit/render_cond.py Normal file
View File

@@ -0,0 +1,146 @@
import os
import json
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
from subprocess import DEVNULL, call
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import numpy as np
from utils import sphere_hammersley_sequence
BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
BLENDER_INSTALLATION_PATH = '/tmp'
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'
def _install_blender():
if not os.path.exists(BLENDER_PATH):
os.system('sudo apt-get update')
os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6 libxfixes3 libgl1')
os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')
def _render_cond(file_path, metadatum, root, num_cond_views):
sha256 = metadatum['sha256']
# Build conditional view camera
yaws = []
pitchs = []
offset = (np.random.rand(), np.random.rand())
for i in range(num_cond_views):
y, p = sphere_hammersley_sequence(i, num_cond_views, offset)
yaws.append(y)
pitchs.append(p)
fov_min, fov_max = 10, 70
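# Assets are normalized to the unit cube, whose bounding sphere has diameter
# sqrt(3); radius = sqrt(3)/2 / sin(fov/2) keeps that sphere just inside the
# frustum. Sampling k = 1/r^2 uniformly and mapping back to (radius, fov)
# pairs varies camera distance while keeping the asset framed.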
radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi)
radius_max = np.sqrt(3) / 2 / np.sin(fov_min / 360 * np.pi)
k_min = 1 / radius_max**2
k_max = 1 / radius_min**2
ks = np.random.uniform(k_min, k_max, (num_cond_views,))  # one sample per view
radius = [1 / np.sqrt(k) for k in ks]
fov = [2 * np.arcsin(np.sqrt(3) / 2 / r) for r in radius]
cond_views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)]
args = [
BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render_cond.py'),
'--',
'--object', os.path.expanduser(file_path),
'--cond_views', json.dumps(cond_views),
'--cond_resolution', '1024',
'--cond_output_folder', os.path.join(root, 'renders_cond', sha256),
'--engine', 'CYCLES',
]
if file_path.endswith('.blend'):
args.insert(1, file_path)
call(args, stdout=DEVNULL, stderr=DEVNULL)
if os.path.exists(os.path.join(root, 'renders_cond', sha256, 'transforms.json')):
return {'sha256': sha256, 'cond_rendered': True}
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--download_root', type=str, default=None,
help='Directory to save the downloaded files')
parser.add_argument('--render_cond_root', type=str, default=None,
help='Directory to save the conditional renders')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process: a file with one sha256 per line, or a comma-separated list')
parser.add_argument('--num_cond_views', type=int, default=16,
help='Number of conditional views to render')
dataset_utils.add_args(parser)
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
parser.add_argument('--max_workers', type=int, default=8)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
opt.download_root = opt.download_root or opt.root
opt.render_cond_root = opt.render_cond_root or opt.root
os.makedirs(os.path.join(opt.render_cond_root, 'renders_cond', 'new_records'), exist_ok=True)
# install blender
print('Checking blender...', flush=True)
_install_blender()
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.download_root, 'raw', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.download_root, 'raw', 'metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.render_cond_root, 'renders_cond', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.render_cond_root, 'renders_cond', 'metadata.csv')).set_index('sha256'))
metadata = metadata.reset_index()
if opt.instances is None:
metadata = metadata[metadata['local_path'].notna()]
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
if 'cond_rendered' in metadata.columns:
metadata = metadata[(metadata['cond_rendered'] != True)]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
records = []
# filter out objects that are already processed
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
tqdm(total=len(metadata), desc="Filtering existing objects") as pbar:
def check_sha256(sha256):
if os.path.exists(os.path.join(opt.render_cond_root, 'renders_cond', sha256, 'transforms.json')):
records.append({'sha256': sha256, 'cond_rendered': True})
pbar.update()
executor.map(check_sha256, metadata['sha256'].values)
executor.shutdown(wait=True)
existing_sha256 = set(r['sha256'] for r in records)
metadata = metadata[~metadata['sha256'].isin(existing_sha256)]
print(f'Processing {len(metadata)} objects...')
# process objects
func = partial(_render_cond, root=opt.render_cond_root, num_cond_views=opt.num_cond_views)
cond_rendered = dataset_utils.foreach_instance(metadata, opt.render_cond_root, func, max_workers=opt.max_workers, desc='Rendering objects')
cond_rendered = pd.concat([cond_rendered, pd.DataFrame.from_records(records)])
cond_rendered.to_csv(os.path.join(opt.render_cond_root, 'renders_cond', 'new_records', f'part_{opt.rank}.csv'), index=False)
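The renderer takes the dataset module name as its first positional argument, then the flags. A sketch with placeholder paths (the `ObjaverseXL_sketchfab` module name mirrors the dataset naming used elsewhere in this release):

```sh
python data_toolkit/render_cond.py ObjaverseXL_sketchfab \
    --root datasets/ObjaverseXL_sketchfab \
    --num_cond_views 16 \
    --max_workers 8
```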

1
data_toolkit/setup.sh Executable file
View File

@@ -0,0 +1 @@
pip install pillow imageio imageio-ffmpeg tqdm easydict opencv-python-headless pandas open3d objaverse huggingface_hub[cli] open_clip_torch

440
data_toolkit/utils.py Executable file
View File

@@ -0,0 +1,440 @@
from typing import *
import hashlib
import numpy as np
import cv2
def get_file_hash(file: str) -> str:
sha256 = hashlib.sha256()
# Read the file from the path
with open(file, "rb") as f:
# Update the hash with the file content
for byte_block in iter(lambda: f.read(4096), b""):
sha256.update(byte_block)
return sha256.hexdigest()
# ===============LOW DISCREPANCY SEQUENCES================
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
def radical_inverse(base, n):
val = 0
inv_base = 1.0 / base
inv_base_n = inv_base
while n > 0:
digit = n % base
val += digit * inv_base_n
n //= base
inv_base_n *= inv_base
return val
def halton_sequence(dim, n):
return [radical_inverse(PRIMES[d], n) for d in range(dim)]
def hammersley_sequence(dim, n, num_samples):
return [n / num_samples] + halton_sequence(dim - 1, n)
def sphere_hammersley_sequence(n, num_samples, offset=(0, 0)):
u, v = hammersley_sequence(2, n, num_samples)
u += offset[0] / num_samples
v += offset[1]
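# Warp u so roughly a quarter of the samples land below the equator
# (u < 0.25) and the rest above, apparently biasing views toward the
# upper hemisphere before mapping to spherical coordinates.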
u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
theta = np.arccos(1 - 2 * u) - np.pi / 2
phi = v * 2 * np.pi
return [phi, theta]
# ==============PLY IO===============
import struct
import re
import torch
def read_ply(filename):
"""
Read a PLY file and return vertices, triangle faces, and quad faces.
Args:
filename (str): The file path to read from.
Returns:
vertices (torch.Tensor): Tensor of shape [N, 3] containing vertex positions.
tris (torch.Tensor): Tensor of shape [M, 3] containing triangle face indices (empty if none).
quads (torch.Tensor): Tensor of shape [K, 4] containing quad face indices (empty if none).
"""
with open(filename, 'rb') as f:
# Read the header until 'end_header' is encountered
header_bytes = b""
while True:
line = f.readline()
if not line:
raise ValueError("PLY header not found")
header_bytes += line
if b"end_header" in line:
break
header = header_bytes.decode('utf-8')
# Determine if the file is in ASCII or binary format
is_ascii = "ascii" in header
# Extract the number of vertices and faces from the header using regex
vertex_match = re.search(r'element vertex (\d+)', header)
if vertex_match:
num_vertices = int(vertex_match.group(1))
else:
raise ValueError("Vertex count not found in header")
face_match = re.search(r'element face (\d+)', header)
if face_match:
num_faces = int(face_match.group(1))
else:
raise ValueError("Face count not found in header")
vertices = []
tris = []
quads = []
if is_ascii:
# For ASCII format, read each line of vertex data (each line contains 3 floats)
for _ in range(num_vertices):
line = f.readline().decode('utf-8').strip()
if not line:
continue
parts = line.split()
vertices.append([float(parts[0]), float(parts[1]), float(parts[2])])
# Read face data, where the first number indicates the number of vertices for the face
for _ in range(num_faces):
line = f.readline().decode('utf-8').strip()
if not line:
continue
parts = line.split()
count = int(parts[0])
indices = list(map(int, parts[1:]))
if count == 3:
tris.append(indices)
elif count == 4:
quads.append(indices)
else:
# Skip faces with other numbers of vertices (can be extended as needed)
pass
else:
# For binary format: read directly from the binary stream
# Each vertex consists of 3 floats (12 bytes per vertex)
for _ in range(num_vertices):
data = f.read(12)
if len(data) < 12:
raise ValueError("Insufficient vertex data")
v = struct.unpack('<fff', data)
vertices.append(v)
# Read face data from the binary stream
for _ in range(num_faces):
# First, read 1 byte indicating the number of vertices in the face
count_data = f.read(1)
if len(count_data) < 1:
raise ValueError("Failed to read face vertex count")
count = struct.unpack('<B', count_data)[0]
if count == 3:
data = f.read(12) # 3 * 4 bytes
if len(data) < 12:
raise ValueError("Insufficient data for triangle face")
indices = struct.unpack('<3i', data)
tris.append(indices)
elif count == 4:
data = f.read(16) # 4 * 4 bytes
if len(data) < 16:
raise ValueError("Insufficient data for quad face")
indices = struct.unpack('<4i', data)
quads.append(indices)
else:
# For faces with a different number of vertices, read count*4 bytes
data = f.read(count * 4)
# Skip or extend processing as needed
raise ValueError(f"Unsupported face with {count} vertices")
# Convert lists to torch.Tensor
vertices = torch.tensor(vertices, dtype=torch.float32)
tris = torch.tensor(tris, dtype=torch.int32) if len(tris) > 0 else torch.empty((0, 3), dtype=torch.int32)
quads = torch.tensor(quads, dtype=torch.int32) if len(quads) > 0 else torch.empty((0, 4), dtype=torch.int32)
return vertices, tris, quads
def write_ply(filename, vertices, tris, quads, ascii=False):
"""
Write a mesh to a PLY file, with the option to save in ASCII or binary format.
Args:
filename (str): The filename to write to.
vertices (torch.Tensor): [N, 3] The vertex positions.
tris (torch.Tensor): [M, 3] The triangle indices.
quads (torch.Tensor): [K, 4] The quad indices.
ascii (bool): If True, write in ASCII format. If False, write in binary format.
"""
# Convert torch tensors to numpy arrays
vertices = vertices.numpy()
tris = tris.numpy()
quads = quads.numpy()
# Prepare the header
num_vertices = len(vertices)
num_faces = len(tris) + len(quads)
# Vertex properties
vertex_header = "property float x\nproperty float y\nproperty float z"
# Face properties (the number of vertices per face is variable)
face_header = "property list uchar int vertex_index"
# Start writing the PLY header
header = f"ply\n"
header += f"format {'ascii 1.0' if ascii else 'binary_little_endian 1.0'}\n"
header += f"element vertex {num_vertices}\n"
header += vertex_header + "\n"
header += f"element face {num_faces}\n"
header += face_header + "\n"
header += "end_header\n"
# Open the file for writing
with open(filename, 'wb' if not ascii else 'w') as f:
# Write the header
f.write(header if ascii else header.encode('utf-8'))
# Write the vertex data
if ascii:
for v in vertices:
f.write(f"{v[0]} {v[1]} {v[2]}\n")
else:
for v in vertices:
f.write(struct.pack('<fff', *v))
# Write the face data
if ascii:
for tri in tris:
f.write(f"3 {tri[0]} {tri[1]} {tri[2]}\n")
for quad in quads:
f.write(f"4 {quad[0]} {quad[1]} {quad[2]} {quad[3]}\n")
else:
for tri in tris:
f.write(struct.pack('<B3i', 3, *tri)) # 3 indices for triangle
for quad in quads:
f.write(struct.pack('<B4i', 4, *quad)) # 4 indices for quad
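# A minimal round-trip sketch (hypothetical usage, not part of the toolkit):
# verts = torch.tensor([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=torch.float32)
# quads = torch.tensor([[0, 1, 2, 3]], dtype=torch.int32)
# write_ply('quad.ply', verts, torch.empty((0, 3), dtype=torch.int32), quads)
# v, t, q = read_ply('quad.ply')  # v: [4, 3], t: [0, 3], q: [1, 4]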
# ==============IMAGE UTILS===============
def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
num_images = len(images)
if nrow is None and ncol is None:
if aspect_ratio is not None:
nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
else:
nrow = int(np.sqrt(num_images))
ncol = (num_images + nrow - 1) // nrow
elif nrow is None and ncol is not None:
nrow = (num_images + ncol - 1) // ncol
elif nrow is not None and ncol is None:
ncol = (num_images + nrow - 1) // nrow
else:
assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'
if images[0].ndim == 2:
grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype)
else:
grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
for i, img in enumerate(images):
row = i // ncol
col = i % ncol
grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
return grid
def notes_on_image(img, notes=None):
img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
if notes is not None:
img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
def text_image(text, resolution=(512, 512), max_size=0.5, h_align="left", v_align="center"):
"""
Draw text on an image of the given resolution. The text is automatically wrapped
and scaled so that it fits completely within the image while preserving any explicit
line breaks and original spacing. Horizontal and vertical alignment can be controlled
via flags.
Parameters:
text (str): The input text. Newline characters and spacing are preserved.
resolution (tuple): The image resolution as (width, height).
max_size (float): The maximum font size.
h_align (str): Horizontal alignment. Options: "left", "center", "right".
v_align (str): Vertical alignment. Options: "top", "center", "bottom".
Returns:
numpy.ndarray: The resulting image (BGR format) with the text drawn.
"""
width, height = resolution
# Create a white background image
img = np.full((height, width, 3), 255, dtype=np.uint8)
# Set margins and compute available drawing area
margin = 10
avail_width = width - 2 * margin
avail_height = height - 2 * margin
# Choose OpenCV font and text thickness
font = cv2.FONT_HERSHEY_SIMPLEX
thickness = 1
# Ratio for additional spacing between lines (relative to the height of "A")
line_spacing_ratio = 0.5
def wrap_line(line, max_width, font, thickness, scale):
"""
Wrap a single line of text into multiple lines such that each line's
width (measured at the given scale) does not exceed max_width.
This function preserves the original spacing by splitting the line into tokens
(words and whitespace) using a regular expression.
Parameters:
line (str): The input text line.
max_width (int): Maximum allowed width in pixels.
font (int): OpenCV font identifier.
thickness (int): Text thickness.
scale (float): The current font scale.
Returns:
List[str]: A list of wrapped lines.
"""
# Split the line into tokens (words and whitespace), preserving spacing
tokens = re.split(r'(\s+)', line)
if not tokens:
return ['']
wrapped_lines = []
current_line = ""
for token in tokens:
candidate = current_line + token
candidate_width = cv2.getTextSize(candidate, font, scale, thickness)[0][0]
if candidate_width <= max_width:
current_line = candidate
else:
# If current_line is empty, the token itself is too wide;
# break the token character by character.
if current_line == "":
sub_token = ""
for char in token:
candidate_char = sub_token + char
if cv2.getTextSize(candidate_char, font, scale, thickness)[0][0] <= max_width:
sub_token = candidate_char
else:
if sub_token:
wrapped_lines.append(sub_token)
sub_token = char
current_line = sub_token
else:
wrapped_lines.append(current_line)
current_line = token
if current_line:
wrapped_lines.append(current_line)
return wrapped_lines
def compute_text_block(scale):
"""
Wrap the entire text (splitting at explicit newline characters) using the
provided scale, and then compute the overall width and height of the text block.
Returns:
wrapped_lines (List[str]): The list of wrapped lines.
block_width (int): Maximum width among the wrapped lines.
block_height (int): Total height of the text block including spacing.
sizes (List[tuple]): A list of (width, height) for each wrapped line.
spacing (int): The spacing between lines (computed from the scaled "A" height).
"""
# Split text by explicit newlines
input_lines = text.splitlines() if text else ['']
wrapped_lines = []
for line in input_lines:
wrapped = wrap_line(line, avail_width, font, thickness, scale)
wrapped_lines.extend(wrapped)
sizes = []
for line in wrapped_lines:
(text_size, _) = cv2.getTextSize(line, font, scale, thickness)
sizes.append(text_size) # (width, height)
block_width = max((w for w, h in sizes), default=0)
# Use the height of "A" (at the current scale) to compute line spacing
base_height = cv2.getTextSize("A", font, scale, thickness)[0][1]
spacing = int(line_spacing_ratio * base_height)
block_height = sum(h for w, h in sizes) + spacing * (len(sizes) - 1) if sizes else 0
return wrapped_lines, block_width, block_height, sizes, spacing
# Use binary search to find the maximum scale that allows the text block to fit
lo = 0.001
hi = max_size
eps = 0.001 # convergence threshold
best_scale = lo
best_result = None
while hi - lo > eps:
mid = (lo + hi) / 2
wrapped_lines, block_width, block_height, sizes, spacing = compute_text_block(mid)
# Ensure that both width and height constraints are met
if block_width <= avail_width and block_height <= avail_height:
best_scale = mid
best_result = (wrapped_lines, block_width, block_height, sizes, spacing)
lo = mid # try a larger scale
else:
hi = mid # reduce the scale
if best_result is None:
best_scale = 0.5
best_result = compute_text_block(best_scale)
wrapped_lines, block_width, block_height, sizes, spacing = best_result
# Compute starting y-coordinate based on vertical alignment flag
if v_align == "top":
y_top = margin
elif v_align == "center":
y_top = margin + (avail_height - block_height) // 2
elif v_align == "bottom":
y_top = margin + (avail_height - block_height)
else:
y_top = margin + (avail_height - block_height) // 2 # default to center if invalid flag
# For cv2.putText, the y coordinate represents the text baseline;
# so for the first line add its height.
y = y_top + (sizes[0][1] if sizes else 0)
# Draw each line with horizontal alignment based on the flag
for i, line in enumerate(wrapped_lines):
line_width, line_height = sizes[i]
if h_align == "left":
x = margin
elif h_align == "center":
x = margin + (avail_width - line_width) // 2
elif h_align == "right":
x = margin + (avail_width - line_width)
else:
x = margin # default to left if invalid flag
cv2.putText(img, line, (x, y), font, best_scale, (0, 0, 0), thickness, cv2.LINE_AA)
y += line_height + spacing
return img
def save_image_with_notes(img, path, notes=None):
"""
Save an image with notes.
"""
if isinstance(img, torch.Tensor):
img = img.cpu().numpy().transpose(1, 2, 0)
if img.dtype == np.float32 or img.dtype == np.float64:
img = np.clip(img * 255, 0, 255).astype(np.uint8)
img = notes_on_image(img, notes)
cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

167
data_toolkit/voxelize_pbr.py Executable file
View File

@@ -0,0 +1,167 @@
import os
import copy
import sys
import importlib
import argparse
import pandas as pd
import pickle
import numpy as np
import torch
from easydict import EasyDict as edict
from functools import partial
import o_voxel
def _pbr_voxelize(file, metadatum, pbr_dump_root, root):
sha256 = metadatum['sha256']
try:
pack = {'sha256': sha256}
dump = None
for res in opt.resolution:
need_process = False
# check if already processed
if os.path.exists(os.path.join(root, f'pbr_voxels_{res}', f'{sha256}.vxz')):
try:
info = o_voxel.io.read_vxz_info(os.path.join(root, f'pbr_voxels_{res}', f'{sha256}.vxz'))
pack[f'pbr_voxelized_{res}'] = True
pack[f'num_pbr_voxels_{res}'] = info['num_voxel']
except Exception as e:
print(f'Error reading {sha256}.vxz: {e}')
need_process = True
else:
need_process = True
# process if necessary
if need_process:
if dump is None:
with open(os.path.join(pbr_dump_root, 'pbr_dumps', f'{sha256}.pickle'), 'rb') as f:
dump = pickle.load(f)
# Fix dump alpha map
for mat in dump['materials']:
if mat['alphaTexture'] is not None and mat['alphaMode'] == 'OPAQUE':
mat['alphaMode'] = 'BLEND'
dump['materials'].append({
"baseColorFactor": [0.8, 0.8, 0.8],
"alphaFactor": 1.0,
"metallicFactor": 0.0,
"roughnessFactor": 0.5,
"alphaMode": "OPAQUE",
"alphaCutoff": 0.5,
"baseColorTexture": None,
"alphaTexture": None,
"metallicTexture": None,
"roughnessTexture": None,
}) # append default material
dump['objects'] = [
obj for obj in dump['objects']
if obj['vertices'].size != 0 and obj['faces'].size != 0
]
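# Normalize all geometry into the [-0.5, 0.5]^3 cube expected by the
# voxelizer; the 0.99999 factor leaves a sliver of margin so the range
# assertion below holds despite floating-point rounding.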
vertices = torch.from_numpy(np.concatenate([obj['vertices'] for obj in dump['objects']], axis=0)).float()
vertices_min = vertices.min(dim=0)[0]
vertices_max = vertices.max(dim=0)[0]
center = (vertices_min + vertices_max) / 2
scale = 0.99999 / (vertices_max - vertices_min).max()
for obj in dump['objects']:
obj['vertices'] = (torch.from_numpy(obj['vertices']).float() - center) * scale
obj['vertices'] = obj['vertices'].numpy()
obj['mat_ids'][obj['mat_ids'] == -1] = len(dump['materials']) - 1
assert np.all(obj['mat_ids'] >= 0), 'invalid mat_ids'
assert np.all(obj['vertices'] >= -0.5) and np.all(obj['vertices'] <= 0.5), 'vertices out of range'
coord, attr = o_voxel.convert.blender_dump_to_volumetric_attr(dump, grid_size=res, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
mip_level_offset=0, verbose=False, timing=False)
del attr['normal']
del attr['emissive']
o_voxel.io.write_vxz(os.path.join(root, f'pbr_voxels_{res}', f'{sha256}.vxz'), coord, attr)
pack[f'pbr_voxelized_{res}'] = True
pack[f'num_pbr_voxels_{res}'] = len(coord)
return pack
except Exception as e:
print(f'Error voxelizing {sha256}: {e}')
return {'sha256': sha256, 'error': str(e)}
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--pbr_dump_root', type=str, default=None,
help='Directory to load mesh dumps')
parser.add_argument('--pbr_voxel_root', type=str, default=None,
help='Directory to save voxelized pbr attributes')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process: a file with one sha256 per line, or a comma-separated list')
dataset_utils.add_args(parser)
parser.add_argument('--resolution', type=str, default='1024')  # comma-separated list; kept a string so .split(',') below works
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
parser.add_argument('--max_workers', type=int, default=0)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
opt.resolution = sorted([int(x) for x in opt.resolution.split(',')], reverse=True)
opt.pbr_dump_root = opt.pbr_dump_root or opt.root
opt.pbr_voxel_root = opt.pbr_voxel_root or opt.root
for res in opt.resolution:
os.makedirs(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{res}', 'new_records'), exist_ok=True)
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')).set_index('sha256'))
for res in opt.resolution:
if os.path.exists(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{res}', 'metadata.csv')):
pbr_voxel_metadata = pd.read_csv(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{res}','metadata.csv')).set_index('sha256')
pbr_voxel_metadata = pbr_voxel_metadata.rename(columns={'pbr_voxelized': f'pbr_voxelized_{res}', 'num_pbr_voxels': f'num_pbr_voxels_{res}'})
metadata = metadata.combine_first(pbr_voxel_metadata)
metadata = metadata.reset_index()
if opt.instances is None:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
metadata = metadata[metadata['pbr_dumped'] == True]
mask = np.zeros(len(metadata), dtype=bool)
for res in opt.resolution:
if f'pbr_voxelized_{res}' in metadata.columns:
mask |= metadata[f'pbr_voxelized_{res}'] != True
else:
mask[:] = True
break
metadata = metadata[mask]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
print(f'Processing {len(metadata)} objects...')
# process objects
func = partial(_pbr_voxelize, pbr_dump_root=opt.pbr_dump_root, root=opt.pbr_voxel_root)
pbr_voxelized = dataset_utils.foreach_instance(metadata, None, func, max_workers=opt.max_workers, no_file=True, desc='Voxelizing')
if 'error' in pbr_voxelized.columns:
errors = pbr_voxelized[pbr_voxelized['error'].notna()]
with open('errors.txt', 'w') as f:
f.write('\n'.join(errors['sha256'].tolist()))
for res in opt.resolution:
if f'pbr_voxelized_{res}' in pbr_voxelized.columns:
pbr_voxel_metadata = pbr_voxelized[pbr_voxelized[f'pbr_voxelized_{res}'] == True]
if len(pbr_voxel_metadata) > 0:
pbr_voxel_metadata = pbr_voxel_metadata[['sha256', f'pbr_voxelized_{res}', f'num_pbr_voxels_{res}']]
pbr_voxel_metadata = pbr_voxel_metadata.rename(columns={f'pbr_voxelized_{res}': 'pbr_voxelized', f'num_pbr_voxels_{res}': 'num_pbr_voxels'})
pbr_voxel_metadata.to_csv(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{res}', 'new_records', f'part_{opt.rank}.csv'), index=False)
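`--resolution` accepts a comma-separated list and is processed from the highest resolution down, writing one `pbr_voxels_{res}` folder per entry. A sketch with placeholder paths:

```sh
python data_toolkit/voxelize_pbr.py ObjaverseXL_sketchfab \
    --root datasets/ObjaverseXL_sketchfab \
    --resolution 1024,512 \
    --rank 0 --world_size 1
```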

158
train.py Normal file
View File

@@ -0,0 +1,158 @@
import os
import sys
import json
import glob
import argparse
from easydict import EasyDict as edict
import torch
import torch.multiprocessing as mp
import numpy as np
import random
from trellis2 import models, datasets, trainers
from trellis2.utils.dist_utils import setup_dist
def find_ckpt(cfg):
# Load checkpoint
cfg['load_ckpt'] = None
if cfg.load_dir != '':
if cfg.ckpt == 'latest':
files = glob.glob(os.path.join(cfg.load_dir, 'ckpts', 'misc_*.pt'))
if len(files) != 0:
cfg.load_ckpt = max([
int(os.path.basename(f).split('step')[-1].split('.')[0])
for f in files
])
elif cfg.ckpt == 'none':
cfg.load_ckpt = None
else:
cfg.load_ckpt = int(cfg.ckpt)
return cfg
def setup_rng(rank):
torch.manual_seed(rank)
torch.cuda.manual_seed_all(rank)
np.random.seed(rank)
random.seed(rank)
def get_model_summary(model):
model_summary = 'Parameters:\n'
model_summary += '=' * 128 + '\n'
model_summary += f'{"Name":<{72}}{"Shape":<{32}}{"Type":<{16}}{"Grad"}\n'
num_params = 0
num_trainable_params = 0
for name, param in model.named_parameters():
model_summary += f'{name:<{72}}{str(param.shape):<{32}}{str(param.dtype):<{16}}{param.requires_grad}\n'
num_params += param.numel()
if param.requires_grad:
num_trainable_params += param.numel()
model_summary += '\n'
model_summary += f'Number of parameters: {num_params}\n'
model_summary += f'Number of trainable parameters: {num_trainable_params}\n'
return model_summary
def main(local_rank, cfg):
# Set up distributed training
rank = cfg.node_rank * cfg.num_gpus + local_rank
world_size = cfg.num_nodes * cfg.num_gpus
if world_size > 1:
setup_dist(rank, local_rank, world_size, cfg.master_addr, cfg.master_port)
# Seed rngs
setup_rng(rank)
# Load data
dataset = getattr(datasets, cfg.dataset.name)(cfg.data_dir, **cfg.dataset.args)
# Build model
model_dict = {
name: getattr(models, model.name)(**model.args).cuda()
for name, model in cfg.models.items()
}
# Model summary
if rank == 0:
for name, backbone in model_dict.items():
model_summary = get_model_summary(backbone)
print(f'\n\nBackbone: {name}\n' + model_summary)
with open(os.path.join(cfg.output_dir, f'{name}_model_summary.txt'), 'w') as fp:
print(model_summary, file=fp)
# Build trainer
trainer = getattr(trainers, cfg.trainer.name)(model_dict, dataset, **cfg.trainer.args, output_dir=cfg.output_dir, load_dir=cfg.load_dir, step=cfg.load_ckpt)
# Train
if not cfg.tryrun:
if cfg.profile:
trainer.profile()
else:
trainer.run()
if __name__ == '__main__':
# Arguments and config
parser = argparse.ArgumentParser()
## config
parser.add_argument('--config', type=str, required=True, help='Experiment config file')
## io and resume
parser.add_argument('--output_dir', type=str, required=True, help='Output directory')
parser.add_argument('--load_dir', type=str, default='', help='Load directory, default to output_dir')
parser.add_argument('--ckpt', type=str, default='latest', help='Checkpoint step to resume training, default to latest')
parser.add_argument('--data_dir', type=str, default='./data/', help='Data directory')
parser.add_argument('--auto_retry', type=int, default=3, help='Number of retries on error')
## debug
parser.add_argument('--tryrun', action='store_true', help='Try run without training')
parser.add_argument('--profile', action='store_true', help='Profile training')
## multi-node and multi-gpu
parser.add_argument('--num_nodes', type=int, default=1, help='Number of nodes')
parser.add_argument('--node_rank', type=int, default=0, help='Node rank')
parser.add_argument('--num_gpus', type=int, default=-1, help='Number of GPUs per node, default to all')
parser.add_argument('--master_addr', type=str, default='localhost', help='Master address for distributed training')
parser.add_argument('--master_port', type=str, default='12345', help='Port for distributed training')
opt = parser.parse_args()
opt.load_dir = opt.load_dir if opt.load_dir != '' else opt.output_dir
opt.num_gpus = torch.cuda.device_count() if opt.num_gpus == -1 else opt.num_gpus
## Load config
config = json.load(open(opt.config, 'r'))
## Combine arguments and config
cfg = edict()
cfg.update(opt.__dict__)
cfg.update(config)
print('\n\nConfig:')
print('=' * 80)
print(json.dumps(cfg.__dict__, indent=4))
# Prepare output directory
if cfg.node_rank == 0:
os.makedirs(cfg.output_dir, exist_ok=True)
## Save command and config
with open(os.path.join(cfg.output_dir, 'command.txt'), 'w') as fp:
print(' '.join(['python'] + sys.argv), file=fp)
with open(os.path.join(cfg.output_dir, 'config.json'), 'w') as fp:
json.dump(config, fp, indent=4)
# Run
if cfg.auto_retry == 0:
cfg = find_ckpt(cfg)
if cfg.num_gpus > 1:
mp.spawn(main, args=(cfg,), nprocs=cfg.num_gpus, join=True)
else:
main(0, cfg)
else:
for rty in range(cfg.auto_retry):
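# find_ckpt is re-evaluated on every attempt, so after a crash the run
# resumes from its most recent checkpoint rather than the original step.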
try:
cfg = find_ckpt(cfg)
if cfg.num_gpus > 1:
mp.spawn(main, args=(cfg,), nprocs=cfg.num_gpus, join=True)
else:
main(0, cfg)
break
except Exception as e:
print(f'Error: {e}')
print(f'Retrying ({rty + 1}/{cfg.auto_retry})...')

View File

@@ -10,10 +10,10 @@ import utils3d
from .components import StandardDatasetBase
from ..modules import sparse as sp
from ..renderers import VoxelRenderer
from ..representations.mesh import Voxel, MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture
from ..representations import Voxel
from ..representations.mesh import MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture
from ..utils.data_utils import load_balanced_group_indices
from ..utils.mesh_utils import subdivide_to_size
def is_power_of_two(n: int) -> bool:
@@ -71,7 +71,7 @@ class SparseVoxelPbrVisMixin:
origin=[-0.5, -0.5, -0.5],
voxel_size=1/self.resolution,
coords=x[i].coords[:, 1:].contiguous(),
attrs=attr,
attrs=None,
layout={
'color': slice(0, 3),
}
@@ -81,7 +81,7 @@ class SparseVoxelPbrVisMixin:
tile = [2, 2]
for j, (ext, intr) in enumerate(zip(exts, ints)):
attr = x[i].feats[:, self.layout[k]].expand(-1, 3)
res = renderer.render(rep, ext, intr)
res = renderer.render(rep, ext, intr, colors_overwrite=attr)
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
images[k].append(image)

View File

@@ -16,12 +16,12 @@ class SLatShapeVisMixin(SLatVisMixin):
return
if self.slat_dec_path is not None:
cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r'))
cfg['models']['decoder']['args']['resolution'] = self.resolution
decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
else:
decoder = models.from_pretrained(self.pretrained_slat_dec)
decoder.set_resolution(self.resolution)
self.slat_dec = decoder.cuda().eval()
@torch.no_grad()
@@ -72,7 +72,7 @@ class SLatShape(SLatShapeVisMixin, SLat):
min_aesthetic_score: float = 5.0,
max_tokens: int = 32768,
normalization: Optional[dict] = None,
pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16',
pretrained_slat_dec: str = 'microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16',
slat_dec_path: Optional[str] = None,
slat_dec_ckpt: Optional[str] = None,
):

View File

@@ -1,14 +1,17 @@
import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
import json
from typing import *
import numpy as np
import torch
import cv2
from .. import models
from .components import StandardDatasetBase, ImageConditionedMixin
from ..modules.sparse import SparseTensor, sparse_cat
from ..representations import MeshWithVoxel
from ..renderers import PbrMeshRenderer, EnvMap
from ..utils.data_utils import load_balanced_group_indices
from ..utils.render_utils import get_renderer, yaw_pitch_r_fov_to_extrinsics_intrinsics
from ..utils.render_utils import yaw_pitch_r_fov_to_extrinsics_intrinsics
class SLatPbrVisMixin:
@@ -47,12 +50,12 @@ class SLatPbrVisMixin:
if self.shape_slat_dec_path is not None:
cfg = json.load(open(os.path.join(self.shape_slat_dec_path, 'config.json'), 'r'))
cfg['models']['decoder']['args']['resolution'] = self.resolution
decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
ckpt_path = os.path.join(self.shape_slat_dec_path, 'ckpts', f'decoder_{self.shape_slat_dec_ckpt}.pt')
decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
else:
decoder = models.from_pretrained(self.pretrained_shape_slat_dec)
decoder.set_resolution(self.resolution)
self.shape_slat_dec = decoder.cuda().eval()
def _delete_slat_dec(self):
@@ -71,7 +74,7 @@ class SLatPbrVisMixin:
z = z * self.pbr_slat_std.to(z.device) + self.pbr_slat_mean.to(z.device)
for i in range(0, z.shape[0], batch_size):
mesh, subs = self.shape_slat_dec(shape_z[i:i+batch_size], return_subs=True)
vox = self.pbr_slat_dec(z[i:i+batch_size], guide_subs=subs)
vox = self.pbr_slat_dec(z[i:i+batch_size], guide_subs=subs) * 0.5 + 0.5
reps.extend([
MeshWithVoxel(
m.vertices, m.faces,
@@ -101,18 +104,32 @@ class SLatPbrVisMixin:
exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30)
# render
renderer = get_renderer(reps[0])
images = {k: [] for k in self.layout}
renderer = PbrMeshRenderer()
renderer.rendering_options.resolution = 512
renderer.rendering_options.near = 1
renderer.rendering_options.far = 100
renderer.rendering_options.ssaa = 2
renderer.rendering_options.peel_layers = 8
envmap = EnvMap(torch.tensor(
cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
dtype=torch.float32, device='cuda'
))
images = {}
for representation in reps:
image = {k: torch.zeros(3, 1024, 1024).cuda() for k in self.layout}
image = {}
tile = [2, 2]
for j, (ext, intr) in enumerate(zip(exts, ints)):
res = renderer.render(representation, ext, intr, return_types=['attr'])
for k in self.layout:
image[k][:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res[k]
for k in self.layout:
res = renderer.render(representation, ext, intr, envmap=envmap)
for k, v in res.items():
if k not in images:
images[k] = []
if k not in image:
image[k] = torch.zeros(3, 1024, 1024).cuda()
image[k][:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = v
for k in images.keys():
images[k].append(image[k])
for k in self.layout:
for k in images.keys():
images[k] = torch.stack(images[k], dim=0)
return images
@@ -156,7 +173,7 @@ class SLatPbr(SLatPbrVisMixin, StandardDatasetBase):
self.min_aesthetic_score = min_aesthetic_score
self.max_tokens = max_tokens
self.full_pbr = full_pbr
self.value_range = (-1, 1)
self.value_range = (0, 1)
super().__init__(
roots,

View File

@@ -87,8 +87,8 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin)
intersected_logits = h.replace(h.feats[..., 3:6])
quad_lerp = h.replace(F.softplus(h.feats[..., 6:7]))
mesh = [Mesh(flexible_dual_grid_to_mesh(
h.coords[:, 1:], v.feats, i.feats, q.feats,
mesh = [Mesh(*flexible_dual_grid_to_mesh(
v.coords[:, 1:], v.feats, i.feats, q.feats,
aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
grid_size=self.resolution,
train=True
@@ -101,7 +101,7 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
intersected = h.replace(h.feats[..., 3:6] > 0)
quad_lerp = h.replace(F.softplus(h.feats[..., 6:7]))
mesh = [Mesh(*flexible_dual_grid_to_mesh(
h.coords[:, 1:], v.feats, i.feats, q.feats,
v.coords[:, 1:], v.feats, i.feats, q.feats,
aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
grid_size=self.resolution,
train=False

View File

@@ -250,9 +250,19 @@ class PbrMeshRenderer:
ssaa = self.rendering_options["ssaa"]
if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
return edict(
shaded=torch.full((4, resolution, resolution), 0.5, dtype=torch.float32, device=self.device),
out_dict = edict(
normal=torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device),
mask=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
base_color=torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device),
metallic=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
roughness=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
alpha=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
clay=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
)
for i, k in enumerate(envmap.keys()):
shaded_key = f"shaded_{k}" if k != '' else "shaded"
out_dict[shaded_key] = torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device)
return out_dict
rays_o, rays_d = utils3d.torch.get_image_rays(
extrinsics, intrinsics, resolution * ssaa, resolution * ssaa

View File

@@ -138,6 +138,7 @@ class ImageConditionedMixin:
"""
with dist_utils.local_master_first():
self.image_cond_model = globals()[self.image_cond_model_config['name']](**self.image_cond_model_config.get('args', {}))
self.image_cond_model.cuda()
@torch.no_grad()
def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: