mirror of
https://github.com/microsoft/TRELLIS.2.git
synced 2026-04-02 02:27:08 -04:00
Release Training Code
113
README.md
@@ -50,7 +50,7 @@ Data processing is streamlined for instant conversions that are fully **renderin
- [x] Release pretrained checkpoints (4B)
- [x] Hugging Face Spaces demo
- [x] Release shape-conditioned texture generation inference code
- [ ] Release training code (Current schedule: before 12/31/2025)
- [x] Release training code

## 🛠️ Installation
@@ -186,6 +186,117 @@ Then, you can access the demo at the address shown in the terminal.

Please refer to [example_texturing.py](example_texturing.py) for an example of how to generate PBR textures for a given 3D shape. You can also use [app_texturing.py](app_texturing.py) to run a web demo for PBR texture generation.
## 🏋️ Training

We provide the full training codebase, enabling users to train **TRELLIS.2** from scratch or fine-tune it on custom datasets.

### 1. Data Preparation

Before training, raw 3D assets must be converted into the **O-Voxel** representation. This process includes mesh conversion, compact structured latent generation, and metadata preparation.

> 📂 **Please refer to [data_toolkit/README.md](data_toolkit/README.md) for detailed instructions on data preprocessing and dataset organization.**

### 2. Running Training

Training is managed through the `train.py` script, which accepts multiple command-line arguments to configure experiments:

* `--config`: Path to the experiment configuration file.
* `--output_dir`: Directory for training outputs.
* `--load_dir`: Directory to load checkpoints from (defaults to `output_dir`).
* `--ckpt`: Checkpoint step to resume from (defaults to the latest).
* `--data_dir`: Dataset path or a JSON string specifying dataset locations.
* `--auto_retry`: Number of automatic retries upon failure.
* `--tryrun`: Perform a dry run without actual training.
* `--profile`: Enable training profiling.
* `--num_nodes`: Number of nodes for distributed training.
* `--node_rank`: Rank of the current node.
* `--num_gpus`: Number of GPUs per node (defaults to all available GPUs).
* `--master_addr`: Master node address for distributed training.
* `--master_port`: Port for distributed training communication.
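As a minimal sketch of how the distributed-training flags above fit together, the following assembles a hypothetical two-node launch command (the master address, port, and config/output paths are placeholders, not values from the repository):

```python
import shlex

# Sketch: build the argv for a hypothetical two-node launch of train.py.
# Flag names come from the list above; address, port, and paths are placeholders.
def launch_cmd(node_rank: int) -> list:
    return [
        "python", "train.py",
        "--config", "configs/gen/ss_flow_img_dit_1_3B_64_bf16.json",
        "--output_dir", "results/ss_flow_img_dit_1_3B_64_bf16",
        "--num_nodes", "2",
        "--node_rank", str(node_rank),
        "--master_addr", "10.0.0.1",
        "--master_port", "29500",
    ]

# One command per node: node_rank 0 on the master machine, 1 on the worker.
print(shlex.join(launch_cmd(0)))
```

Each node runs the same command with its own `--node_rank`; everything else stays identical across nodes.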
### SC-VAE Training

To train the shape SC-VAE, run:

```sh
python train.py \
    --config configs/scvae/shape_vae_next_dc_f16c32_fp16.json \
    --output_dir results/shape_vae_next_dc_f16c32_fp16 \
    --data_dir "{\"ObjaverseXL_sketchfab\": {\"base\": \"datasets/ObjaverseXL_sketchfab\", \"mesh_dump\": \"datasets/ObjaverseXL_sketchfab/mesh_dumps\", \"dual_grid\": \"datasets/ObjaverseXL_sketchfab/dual_grid_256\", \"asset_stats\": \"datasets/ObjaverseXL_sketchfab/asset_stats\"}}"
```
This command trains the shape SC-VAE on the **Objaverse-XL** dataset using the `shape_vae_next_dc_f16c32_fp16.json` configuration. Training outputs will be saved to `results/shape_vae_next_dc_f16c32_fp16`.

The dataset is specified as a JSON string, where each dataset entry includes:

* `base`: Root directory of the dataset.
* `mesh_dump`: Directory containing preprocessed mesh dumps.
* `dual_grid`: Directory with precomputed dual-grid representations.
* `asset_stats`: Directory containing precomputed asset statistics.
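Hand-escaping the quotes in the `--data_dir` JSON string is error-prone; a small sketch like the following (a hypothetical helper, not part of the repository) can generate it instead:

```python
import json

# Build the --data_dir argument programmatically instead of escaping quotes
# by hand in the shell. The keys (base, mesh_dump, dual_grid, asset_stats)
# follow the dataset layout described above; paths match the example command.
root = "datasets/ObjaverseXL_sketchfab"
data_dir = {
    "ObjaverseXL_sketchfab": {
        "base": root,
        "mesh_dump": f"{root}/mesh_dumps",
        "dual_grid": f"{root}/dual_grid_256",
        "asset_stats": f"{root}/asset_stats",
    }
}
arg = json.dumps(data_dir)  # pass this string as the value of --data_dir
print(arg)
```

The printed string can be stored in a shell variable and passed as `--data_dir "$ARG"`, which avoids the backslash escaping shown in the command above.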
To fine-tune the model at a higher resolution, use the `shape_vae_next_dc_f16c32_fp16_ft_512.json` configuration. Remember to update the `finetune_ckpt` field and adjust the dataset paths accordingly.

To train the texture SC-VAE, run:

```sh
python train.py \
    --config configs/scvae/tex_vae_next_dc_f16c32_fp16.json \
    --output_dir results/tex_vae_next_dc_f16c32_fp16 \
    --data_dir "{\"ObjaverseXL_sketchfab\": {\"base\": \"datasets/ObjaverseXL_sketchfab\", \"pbr_dump\": \"datasets/ObjaverseXL_sketchfab/pbr_dumps\", \"pbr_voxel\": \"datasets/ObjaverseXL_sketchfab/pbr_voxels_256\", \"asset_stats\": \"datasets/ObjaverseXL_sketchfab/asset_stats\"}}"
```
### Flow Model Training

To train the sparse structure flow model, run:

```sh
python train.py \
    --config configs/gen/ss_flow_img_dit_1_3B_64_bf16.json \
    --output_dir results/ss_flow_img_dit_1_3B_64_bf16 \
    --data_dir "{\"ObjaverseXL_sketchfab\": {\"base\": \"datasets/ObjaverseXL_sketchfab\", \"ss_latent\": \"datasets/ObjaverseXL_sketchfab/ss_latents/ss_enc_conv3d_16l8_fp16_64\", \"render_cond\": \"datasets/ObjaverseXL_sketchfab/renders_cond\"}}"
```

This command trains the sparse-structure flow model on the **Objaverse-XL** dataset using the specified configuration file. Outputs are saved to `results/ss_flow_img_dit_1_3B_64_bf16`.

The dataset configuration includes:

* `base`: Root dataset directory.
* `ss_latent`: Directory containing precomputed sparse-structure latents.
* `render_cond`: Directory containing conditional rendering images.
The second- and third-stage flow models for shape and texture generation can be trained using the following configurations:

* Shape flow: `slat_flow_img2shape_dit_1_3B_512_bf16.json`
* Texture flow: `slat_flow_imgshape2tex_dit_1_3B_512_bf16.json`

Example commands:

```sh
# Shape flow model
python train.py \
    --config configs/gen/slat_flow_img2shape_dit_1_3B_512_bf16.json \
    --output_dir results/slat_flow_img2shape_dit_1_3B_512_bf16 \
    --data_dir "{\"ObjaverseXL_sketchfab\": {\"base\": \"datasets/ObjaverseXL_sketchfab\", \"shape_latent\": \"datasets/ObjaverseXL_sketchfab/shape_latents/shape_enc_next_dc_f16c32_fp16_512\", \"render_cond\": \"datasets/ObjaverseXL_sketchfab/renders_cond\"}}"

# Texture flow model
python train.py \
    --config configs/gen/slat_flow_imgshape2tex_dit_1_3B_512_bf16.json \
    --output_dir results/slat_flow_imgshape2tex_dit_1_3B_512_bf16 \
    --data_dir "{\"ObjaverseXL_sketchfab\": {\"base\": \"datasets/ObjaverseXL_sketchfab\", \"shape_latent\": \"datasets/ObjaverseXL_sketchfab/shape_latents/shape_enc_next_dc_f16c32_fp16_512\", \"pbr_latent\": \"datasets/ObjaverseXL_sketchfab/pbr_latents/tex_enc_next_dc_f16c32_fp16_512\", \"render_cond\": \"datasets/ObjaverseXL_sketchfab/renders_cond\"}}"
```
Higher-resolution fine-tuning can be performed by updating the `finetune_ckpt` field in the following configuration files and adjusting the dataset paths accordingly:

* `slat_flow_img2shape_dit_1_3B_512_bf16_ft1024.json`
* `slat_flow_imgshape2tex_dit_1_3B_512_bf16_ft1024.json`
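A minimal sketch of updating that field, assuming the `trainer.args.finetune_ckpt` nesting shown in the `*_ft1024.json` configs in this commit (the checkpoint path below is a placeholder for your own trained 512-resolution checkpoint):

```python
import json
import os
import tempfile

# Rewrite the finetune_ckpt entry of a config file in place.
# The key nesting (trainer -> args -> finetune_ckpt -> denoiser) matches
# the *_ft1024.json configs; the checkpoint path is a placeholder.
def set_finetune_ckpt(cfg_path: str, ckpt_path: str) -> None:
    with open(cfg_path) as f:
        cfg = json.load(f)
    cfg["trainer"]["args"]["finetune_ckpt"]["denoiser"] = ckpt_path
    with open(cfg_path, "w") as f:
        json.dump(cfg, f, indent=4)

# Demo on a throwaway file mimicking the relevant part of the config.
demo = {"trainer": {"args": {"finetune_ckpt": {"denoiser": "PATH_TO_512_CKPT"}}}}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(demo, f)
    path = f.name

set_finetune_ckpt(path, "results/slat_flow_img2shape_dit_1_3B_512_bf16/ckpts/denoiser.pt")
with open(path) as f:
    updated = json.load(f)
print(updated["trainer"]["args"]["finetune_ckpt"]["denoiser"])
os.remove(path)  # clean up the demo file
```

The same pattern applies to the SC-VAE fine-tuning configs, whose `finetune_ckpt` holds `encoder` and `decoder` entries instead of `denoiser`.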
## 🧩 Related Packages

TRELLIS.2 is built upon several specialized high-performance packages developed by our team:
98
configs/gen/slat_flow_img2shape_dit_1_3B_512_bf16.json
Executable file
@@ -0,0 +1,98 @@
{
    "models": {
        "denoiser": {
            "name": "ElasticSLatFlowModel",
            "args": {
                "resolution": 32,
                "in_channels": 32,
                "out_channels": 32,
                "model_channels": 1536,
                "cond_channels": 1024,
                "num_blocks": 30,
                "num_heads": 12,
                "mlp_ratio": 5.3334,
                "pe_mode": "rope",
                "share_mod": true,
                "initialization": "scaled",
                "qk_rms_norm": true,
                "qk_rms_norm_cross": true
            }
        }
    },
    "dataset": {
        "name": "ImageConditionedSLatShape",
        "args": {
            "resolution": 512,
            "image_size": 512,
            "min_aesthetic_score": 4.5,
            "max_tokens": 8192,
            "normalization": {
                "mean": [
                    0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
                    -0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
                    0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
                    -0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
                ],
                "std": [
                    5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
                    5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
                    4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
                    5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
                ]
            },
            "pretrained_slat_dec": "microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16"
        }
    },
    "trainer": {
        "name": "ImageConditionedSparseFlowMatchingCFGTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 8,
            "batch_split": 2,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 1e-4,
                    "weight_decay": 0.01,
                    "betas": [0.9, 0.95],
                    "eps": 1e-8
                }
            },
            "ema_rate": [
                0.9999
            ],
            "mix_precision_mode": "amp",
            "mix_precision_dtype": "bfloat16",
            "elastic": {
                "name": "LinearMemoryController",
                "args": {
                    "target_ratio": 0.75,
                    "max_mem_ratio_start": 0.5
                }
            },
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "p_uncond": 0.1,
            "t_schedule": {
                "name": "uniform",
                "args": {}
            },
            "sigma_min": 1e-5,
            "image_cond_model": {
                "name": "DinoV3FeatureExtractor",
                "args": {
                    "model_name": "facebook/dinov3-vitl16-pretrain-lvd1689m",
                    "image_size": 512
                }
            }
        }
    }
}
101
configs/gen/slat_flow_img2shape_dit_1_3B_512_bf16_ft1024.json
Executable file
@@ -0,0 +1,101 @@
{
    "models": {
        "denoiser": {
            "name": "ElasticSLatFlowModel",
            "args": {
                "resolution": 64,
                "in_channels": 32,
                "out_channels": 32,
                "model_channels": 1536,
                "cond_channels": 1024,
                "num_blocks": 30,
                "num_heads": 12,
                "mlp_ratio": 5.3334,
                "pe_mode": "rope",
                "share_mod": true,
                "initialization": "scaled",
                "qk_rms_norm": true,
                "qk_rms_norm_cross": true
            }
        }
    },
    "dataset": {
        "name": "ImageConditionedSLatShape",
        "args": {
            "resolution": 1024,
            "image_size": 1024,
            "min_aesthetic_score": 4.5,
            "max_tokens": 32768,
            "normalization": {
                "mean": [
                    0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
                    -0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
                    0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
                    -0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
                ],
                "std": [
                    5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
                    5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
                    4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
                    5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
                ]
            },
            "pretrained_slat_dec": "microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16"
        }
    },
    "trainer": {
        "name": "ImageConditionedSparseFlowMatchingCFGTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 2,
            "batch_split": 1,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 2e-5,
                    "weight_decay": 0.01,
                    "betas": [0.9, 0.95],
                    "eps": 1e-8
                }
            },
            "ema_rate": [
                0.9999
            ],
            "mix_precision_mode": "amp",
            "mix_precision_dtype": "bfloat16",
            "elastic": {
                "name": "LinearMemoryController",
                "args": {
                    "target_ratio": 0.75,
                    "max_mem_ratio_start": 0.25
                }
            },
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "finetune_ckpt": {
                "denoiser": "PATH_TO_512_CKPT"
            },
            "i_log": 500,
            "i_sample": 1000,
            "i_save": 1000,
            "p_uncond": 0.1,
            "t_schedule": {
                "name": "uniform",
                "args": {}
            },
            "sigma_min": 1e-5,
            "image_cond_model": {
                "name": "DinoV3FeatureExtractor",
                "args": {
                    "model_name": "facebook/dinov3-vitl16-pretrain-lvd1689m",
                    "image_size": 1024
                }
            }
        }
    }
}
119
configs/gen/slat_flow_imgshape2tex_dit_1_3B_512_bf16.json
Executable file
@@ -0,0 +1,119 @@
{
    "models": {
        "denoiser": {
            "name": "ElasticSLatFlowModel",
            "args": {
                "resolution": 32,
                "in_channels": 64,
                "out_channels": 32,
                "model_channels": 1536,
                "cond_channels": 1024,
                "num_blocks": 30,
                "num_heads": 12,
                "mlp_ratio": 5.3334,
                "pe_mode": "rope",
                "share_mod": true,
                "initialization": "scaled",
                "qk_rms_norm": true,
                "qk_rms_norm_cross": true
            }
        }
    },
    "dataset": {
        "name": "ImageConditionedSLatPbr",
        "args": {
            "resolution": 512,
            "image_size": 512,
            "min_aesthetic_score": 4.5,
            "max_tokens": 8192,
            "pbr_slat_normalization": {
                "mean": [
                    3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
                    0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
                    -1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
                    1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862
                ],
                "std": [
                    2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
                    2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
                    2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
                    2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190
                ]
            },
            "shape_slat_normalization": {
                "mean": [
                    0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
                    -0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
                    0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
                    -0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
                ],
                "std": [
                    5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
                    5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
                    4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
                    5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
                ]
            },
            "attrs": [
                "base_color",
                "metallic",
                "roughness",
                "alpha"
            ],
            "pretrained_pbr_slat_dec": "microsoft/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16",
            "pretrained_shape_slat_dec": "microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16"
        }
    },
    "trainer": {
        "name": "ImageConditionedSparseFlowMatchingCFGTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 8,
            "batch_split": 2,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 1e-4,
                    "weight_decay": 0.01,
                    "betas": [0.9, 0.95],
                    "eps": 1e-8
                }
            },
            "ema_rate": [
                0.9999
            ],
            "mix_precision_mode": "amp",
            "mix_precision_dtype": "bfloat16",
            "elastic": {
                "name": "LinearMemoryController",
                "args": {
                    "target_ratio": 0.75,
                    "max_mem_ratio_start": 0.5
                }
            },
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "p_uncond": 0.1,
            "t_schedule": {
                "name": "uniform",
                "args": {}
            },
            "sigma_min": 1e-5,
            "image_cond_model": {
                "name": "DinoV3FeatureExtractor",
                "args": {
                    "model_name": "facebook/dinov3-vitl16-pretrain-lvd1689m",
                    "image_size": 512
                }
            }
        }
    }
}
120
configs/gen/slat_flow_imgshape2tex_dit_1_3B_512_bf16_ft1024.json
Executable file
@@ -0,0 +1,120 @@
{
    "models": {
        "denoiser": {
            "name": "ElasticSLatFlowModel",
            "args": {
                "resolution": 32,
                "in_channels": 64,
                "out_channels": 32,
                "model_channels": 1536,
                "cond_channels": 1024,
                "num_blocks": 30,
                "num_heads": 12,
                "mlp_ratio": 5.3334,
                "pe_mode": "rope",
                "share_mod": true,
                "initialization": "scaled",
                "qk_rms_norm": true,
                "qk_rms_norm_cross": true
            }
        }
    },
    "dataset": {
        "name": "ImageConditionedSLatPbr",
        "args": {
            "resolution": 1024,
            "image_size": 1024,
            "min_aesthetic_score": 4.5,
            "max_tokens": 32768,
            "full_pbr": true,
            "pbr_slat_normalization": {
                "mean": [
                    3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
                    0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
                    -1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
                    1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862
                ],
                "std": [
                    2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
                    2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
                    2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
                    2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190
                ]
            },
            "shape_slat_normalization": {
                "mean": [
                    0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
                    -0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
                    0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
                    -0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
                ],
                "std": [
                    5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
                    5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
                    4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
                    5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
                ]
            },
            "attrs": [
                "base_color",
                "metallic",
                "roughness",
                "alpha"
            ],
            "pretrained_pbr_slat_dec": "microsoft/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16",
            "pretrained_shape_slat_dec": "microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16"
        }
    },
    "trainer": {
        "name": "ImageConditionedSparseFlowMatchingCFGTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 2,
            "batch_split": 1,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 2e-5,
                    "weight_decay": 0.01,
                    "betas": [0.9, 0.95],
                    "eps": 1e-8
                }
            },
            "ema_rate": [
                0.9999
            ],
            "mix_precision_mode": "amp",
            "mix_precision_dtype": "bfloat16",
            "elastic": {
                "name": "LinearMemoryController",
                "args": {
                    "target_ratio": 0.75,
                    "max_mem_ratio_start": 0.25
                }
            },
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 1000,
            "i_save": 1000,
            "p_uncond": 0.1,
            "t_schedule": {
                "name": "uniform",
                "args": {}
            },
            "sigma_min": 1e-5,
            "image_cond_model": {
                "name": "DinoV3FeatureExtractor",
                "args": {
                    "model_name": "facebook/dinov3-vitl16-pretrain-lvd1689m",
                    "image_size": 1024
                }
            }
        }
    }
}
78
configs/gen/ss_flow_img_dit_1_3B_64_bf16.json
Executable file
@@ -0,0 +1,78 @@
{
    "models": {
        "denoiser": {
            "name": "SparseStructureFlowModel",
            "args": {
                "resolution": 16,
                "in_channels": 8,
                "out_channels": 8,
                "model_channels": 1536,
                "cond_channels": 1024,
                "num_blocks": 30,
                "num_heads": 12,
                "mlp_ratio": 5.3334,
                "pe_mode": "rope",
                "share_mod": true,
                "initialization": "scaled",
                "qk_rms_norm": true,
                "qk_rms_norm_cross": true
            }
        }
    },
    "dataset": {
        "name": "ImageConditionedSparseStructureLatent",
        "args": {
            "min_aesthetic_score": 4.5,
            "image_size": 512,
            "pretrained_ss_dec": "microsoft/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
        }
    },
    "trainer": {
        "name": "ImageConditionedFlowMatchingCFGTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 8,
            "batch_split": 4,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 1e-4,
                    "weight_decay": 0.01,
                    "betas": [0.9, 0.95],
                    "eps": 1e-8
                }
            },
            "ema_rate": [
                0.9999
            ],
            "mix_precision_mode": "amp",
            "mix_precision_dtype": "bfloat16",
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "p_uncond": 0.1,
            "t_schedule": {
                "name": "logitNormal",
                "args": {
                    "mean": 1.0,
                    "std": 1.0
                }
            },
            "sigma_min": 1e-5,
            "image_cond_model": {
                "name": "DinoV3FeatureExtractor",
                "args": {
                    "model_name": "facebook/dinov3-vitl16-pretrain-lvd1689m",
                    "image_size": 512
                }
            }
        }
    }
}
134
configs/scvae/shape_vae_next_dc_f16c32_fp16.json
Executable file
@@ -0,0 +1,134 @@
{
    "models": {
        "encoder": {
            "name": "FlexiDualGridVaeEncoder",
            "args": {
                "model_channels": [64, 128, 256, 512, 1024],
                "latent_channels": 32,
                "num_blocks": [0, 4, 8, 16, 4],
                "block_type": [
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d"
                ],
                "down_block_type": [
                    "SparseResBlockS2C3d",
                    "SparseResBlockS2C3d",
                    "SparseResBlockS2C3d",
                    "SparseResBlockS2C3d"
                ],
                "block_args": [
                    {
                        "use_checkpoint": true
                    },
                    {
                        "use_checkpoint": true
                    },
                    {
                        "use_checkpoint": false
                    },
                    {
                        "use_checkpoint": false
                    },
                    {
                        "use_checkpoint": false
                    }
                ],
                "use_fp16": true
            }
        },
        "decoder": {
            "name": "FlexiDualGridVaeDecoder",
            "args": {
                "resolution": 256,
                "model_channels": [1024, 512, 256, 128, 64],
                "latent_channels": 32,
                "num_blocks": [4, 16, 8, 4, 0],
                "block_type": [
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d"
                ],
                "up_block_type": [
                    "SparseResBlockC2S3d",
                    "SparseResBlockC2S3d",
                    "SparseResBlockC2S3d",
                    "SparseResBlockC2S3d"
                ],
                "block_args": [
                    {
                        "use_checkpoint": false
                    },
                    {
                        "use_checkpoint": false
                    },
                    {
                        "use_checkpoint": false
                    },
                    {
                        "use_checkpoint": true
                    },
                    {
                        "use_checkpoint": true
                    }
                ],
                "use_fp16": true
            }
        }
    },
    "dataset": {
        "name": "FlexiDualGridDataset",
        "args": {
            "resolution": 256,
            "max_active_voxels": 1000000,
            "max_num_faces": 1000000,
            "min_aesthetic_score": 4.5
        }
    },
    "trainer": {
        "name": "ShapeVaeTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 8,
            "batch_split": 2,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 1e-4,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "lambda_subdiv": 0.1,
            "lambda_intersected": 0.1,
            "lambda_vertice": 1e-2,
            "lambda_mask": 1,
            "lambda_depth": 10,
            "lambda_normal": 1,
            "lambda_kl": 1e-6,
            "lambda_ssim": 0.2,
            "lambda_lpips": 0.2,
            "camera_randomization_config": {
                "radius_range": [2, 100]
            }
        }
    }
}
140
configs/scvae/shape_vae_next_dc_f16c32_fp16_ft_512.json
Executable file
@@ -0,0 +1,140 @@
{
    "models": {
        "encoder": {
            "name": "FlexiDualGridVaeEncoder",
            "args": {
                "model_channels": [64, 128, 256, 512, 1024],
                "latent_channels": 32,
                "num_blocks": [0, 4, 8, 16, 4],
                "block_type": [
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d"
                ],
                "down_block_type": [
                    "SparseResBlockS2C3d",
                    "SparseResBlockS2C3d",
                    "SparseResBlockS2C3d",
                    "SparseResBlockS2C3d"
                ],
                "block_args": [
                    {
                        "use_checkpoint": true
                    },
                    {
                        "use_checkpoint": true
                    },
                    {
                        "use_checkpoint": true
                    },
                    {
                        "use_checkpoint": true
                    },
                    {
                        "use_checkpoint": true
                    }
                ],
                "use_fp16": true
            }
        },
        "decoder": {
            "name": "FlexiDualGridVaeDecoder",
            "args": {
                "resolution": 512,
                "model_channels": [1024, 512, 256, 128, 64],
                "latent_channels": 32,
                "num_blocks": [4, 16, 8, 4, 0],
                "block_type": [
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d"
                ],
                "up_block_type": [
                    "SparseResBlockC2S3d",
                    "SparseResBlockC2S3d",
                    "SparseResBlockC2S3d",
                    "SparseResBlockC2S3d"
                ],
                "block_args": [
                    {
                        "use_checkpoint": true
                    },
                    {
                        "use_checkpoint": true
                    },
                    {
                        "use_checkpoint": true
                    },
                    {
                        "use_checkpoint": true
                    },
                    {
                        "use_checkpoint": true
                    }
                ],
                "use_fp16": true
            }
        }
    },
    "dataset": {
        "name": "FlexiDualGridDataset",
        "args": {
            "resolution": 512,
            "max_active_voxels": 1000000,
            "max_num_faces": 1000000,
            "min_aesthetic_score": 4.5
        }
    },
    "trainer": {
        "name": "ShapeVaeTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 4,
            "batch_split": 2,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 1e-5,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "finetune_ckpt": {
                "encoder": "PATH_TO_ENCODER_CKPT",
                "decoder": "PATH_TO_DECODER_CKPT"
            },
            "snapshot_batch_size": 1,
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "lambda_subdiv": 0.1,
            "lambda_intersected": 0.1,
            "lambda_vertice": 1e-2,
            "lambda_mask": 1,
            "lambda_depth": 10,
            "lambda_normal": 1,
            "lambda_kl": 1e-6,
            "lambda_ssim": 0.2,
            "lambda_lpips": 0.2,
            "render_resolution": 1024,
            "camera_randomization_config": {
                "radius_range": [2, 100]
            }
        }
    }
}
134
configs/scvae/tex_vae_next_dc_f16c32_fp16.json
Executable file
@@ -0,0 +1,134 @@
{
    "models": {
        "encoder": {
            "name": "SparseUnetVaeEncoder",
            "args": {
                "in_channels": 6,
                "model_channels": [64, 128, 256, 512, 1024],
                "latent_channels": 32,
                "num_blocks": [0, 4, 8, 16, 4],
                "block_type": [
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d"
                ],
                "down_block_type": [
                    "SparseResBlockS2C3d",
                    "SparseResBlockS2C3d",
                    "SparseResBlockS2C3d",
                    "SparseResBlockS2C3d"
                ],
                "block_args": [
                    {
                        "use_checkpoint": true
                    },
                    {
                        "use_checkpoint": true
                    },
                    {
                        "use_checkpoint": false
                    },
                    {
                        "use_checkpoint": false
                    },
                    {
                        "use_checkpoint": false
                    }
                ],
                "use_fp16": true
            }
        },
        "decoder": {
            "name": "SparseUnetVaeDecoder",
            "args": {
                "out_channels": 6,
                "model_channels": [1024, 512, 256, 128, 64],
                "latent_channels": 32,
                "num_blocks": [4, 16, 8, 4, 0],
                "block_type": [
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d"
                ],
                "up_block_type": [
                    "SparseResBlockC2S3d",
                    "SparseResBlockC2S3d",
                    "SparseResBlockC2S3d",
                    "SparseResBlockC2S3d"
                ],
                "block_args": [
                    {
                        "use_checkpoint": false
                    },
                    {
                        "use_checkpoint": false
                    },
                    {
                        "use_checkpoint": false
                    },
                    {
                        "use_checkpoint": true
                    },
                    {
                        "use_checkpoint": true
                    }
                ],
                "use_fp16": true,
                "pred_subdiv": false
            }
        }
    },
    "dataset": {
        "name": "SparseVoxelPbrDataset",
        "args": {
            "resolution": 256,
            "min_aesthetic_score": 4.5,
            "max_active_voxels": 1000000,
            "max_num_faces": 1000000,
            "with_mesh": false,
            "attrs": [
                "base_color",
                "metallic",
                "roughness",
                "alpha"
            ]
        }
    },
    "trainer": {
        "name": "PbrVaeTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 8,
            "batch_split": 1,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 1e-4,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "lambda_kl": 1e-6,
            "loss_type": "l1",
            "lambda_render": 0.0
        }
    }
}
144
configs/scvae/tex_vae_next_dc_f16c32_fp16_ft_512.json
Executable file
@@ -0,0 +1,144 @@
{
    "models": {
        "encoder": {
            "name": "SparseUnetVaeEncoder",
            "args": {
                "in_channels": 6,
                "model_channels": [64, 128, 256, 512, 1024],
                "latent_channels": 32,
                "num_blocks": [0, 4, 8, 16, 4],
                "block_type": [
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d"
                ],
                "down_block_type": [
                    "SparseResBlockS2C3d",
                    "SparseResBlockS2C3d",
                    "SparseResBlockS2C3d",
                    "SparseResBlockS2C3d"
                ],
                "block_args": [
                    {"use_checkpoint": true},
                    {"use_checkpoint": true},
                    {"use_checkpoint": true},
                    {"use_checkpoint": true},
                    {"use_checkpoint": true}
                ],
                "use_fp16": true
            }
        },
        "decoder": {
            "name": "SparseUnetVaeDecoder",
            "args": {
                "out_channels": 6,
                "model_channels": [1024, 512, 256, 128, 64],
                "latent_channels": 32,
                "num_blocks": [4, 16, 8, 4, 0],
                "block_type": [
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d",
                    "SparseConvNeXtBlock3d"
                ],
                "up_block_type": [
                    "SparseResBlockC2S3d",
                    "SparseResBlockC2S3d",
                    "SparseResBlockC2S3d",
                    "SparseResBlockC2S3d"
                ],
                "block_args": [
                    {"use_checkpoint": true},
                    {"use_checkpoint": true},
                    {"use_checkpoint": true},
                    {"use_checkpoint": true},
                    {"use_checkpoint": true}
                ],
                "use_fp16": true,
                "pred_subdiv": false
            }
        }
    },
    "dataset": {
        "name": "SparseVoxelPbrDataset",
        "args": {
            "resolution": 512,
            "min_aesthetic_score": 4.5,
            "max_active_voxels": 1000000,
            "max_num_faces": 1000000,
            "attrs": [
                "base_color",
                "metallic",
                "roughness",
                "alpha"
            ]
        }
    },
    "trainer": {
        "name": "PbrVaeTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 4,
            "batch_split": 2,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 1e-5,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "finetune_ckpt": {
                "encoder": "PATH_TO_ENCODER_CKPT",
                "decoder": "PATH_TO_DECODER_CKPT"
            },
            "snapshot_batch_size": 1,
            "render_resolution": 512,
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "lambda_kl": 1e-6,
            "lambda_render": 1.0,
            "loss_type": "l1",
            "lambda_ssim": 0.2,
            "lambda_lpips": 0.2,
            "camera_randomization_config": {
                "radius_range": [2, 100]
            }
        }
    }
}
172
data_toolkit/README.md
Normal file
@@ -0,0 +1,172 @@
# Dataset Preparation Toolkit

This toolkit provides a comprehensive pipeline for preparing 3D datasets, including downloading, processing, voxelizing, and latent encoding for SC-VAE and Flow Model training.

### Step 1: Install Dependencies

Initialize the environment and install the necessary dependencies:

```bash
. ./data_toolkit/setup.sh
```

### Step 2: Initialize Metadata

Before processing, load the dataset metadata.

```bash
python data_toolkit/build_metadata.py <SUBSET> --root <ROOT> [--source <SOURCE>]
```

**Arguments:**
- `SUBSET`: Target dataset subset. Options: `ObjaverseXL`, `ABO`, `HSSD`, `TexVerse` (training sets); `SketchfabPicked`, `Toys4k` (test sets).
- `ROOT`: Root directory to save the data.
- `SOURCE`: Data source (required if `SUBSET` is `ObjaverseXL`). Options: `sketchfab`, `github`.

**Example:**
Load metadata for `ObjaverseXL` (sketchfab) and save to `datasets/ObjaverseXL_sketchfab`:
```bash
python data_toolkit/build_metadata.py ObjaverseXL --source sketchfab --root datasets/ObjaverseXL_sketchfab
```

### Step 3: Download Data

Download the 3D assets to local storage.

```bash
python data_toolkit/download.py <SUBSET> --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>]
```

**Arguments:**
- `RANK` / `WORLD_SIZE`: Parameters for multi-node distributed downloading.

**Example:**
To download the `ObjaverseXL` subset:

> **Note:** The example below sets a large `WORLD_SIZE` (160,000) for demonstration purposes, meaning only a tiny fraction of the dataset will be downloaded by this single process.

```bash
python data_toolkit/download.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab --world_size 160000
```

*Attention: Some datasets may require an interactive Hugging Face login or manual steps. Please follow any on-screen instructions.*

**Update Metadata:**
After downloading, update the metadata registry:
```bash
python data_toolkit/build_metadata.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab
```

### Step 4: Process Mesh and PBR Textures

Standardize 3D assets by dumping meshes and PBR textures.
*Note: This process utilizes the CPU.*

```bash
# Dump meshes
python data_toolkit/dump_mesh.py <SUBSET> --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>]

# Dump PBR textures
python data_toolkit/dump_pbr.py <SUBSET> --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>]

# Get statistics of the assets
python data_toolkit/asset_stats.py --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>]
```

**Example:**
```bash
python data_toolkit/dump_mesh.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab
python data_toolkit/dump_pbr.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab
python data_toolkit/asset_stats.py --root datasets/ObjaverseXL_sketchfab
```

**Update Metadata:**
```bash
python data_toolkit/build_metadata.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab
```

### Step 5: Convert to O-Voxels

Convert the processed meshes and textures into the O-Voxel format.
*Note: This process utilizes the CPU.*

```bash
python data_toolkit/dual_grid.py <SUBSET> --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>] [--resolution <RESOLUTION>]

python data_toolkit/voxelize_pbr.py <SUBSET> --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>] [--resolution <RESOLUTION>]
```

**Arguments:**
- `RESOLUTION`: Target resolutions for O-Voxels, comma-separated (e.g., `256,512,1024`). Default is `256`.

**Example:**
Convert `ObjaverseXL` to resolutions 256, 512, and 1024:
```bash
python data_toolkit/dual_grid.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab --resolution 256,512,1024
python data_toolkit/voxelize_pbr.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab --resolution 256,512,1024
```


### At this point, the dataset is ready for SC-VAE Training

### Step 6: Encode Latents

Encode sparse structures into latents to train the first-stage generator.

```bash
# 1. Encode Shape Latents
python data_toolkit/encode_shape_latent.py --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>] [--resolution <RESOLUTION>]

# 2. Encode PBR Latents
python data_toolkit/encode_pbr_latent.py --root <ROOT> [--rank <RANK> --world_size <WORLD_SIZE>] [--resolution <RESOLUTION>]

# 3. Update Metadata (required before the next step)
python data_toolkit/build_metadata.py <SUBSET> --root <ROOT>

# 4. Encode Sparse Structure (SS) Latents
python data_toolkit/encode_ss_latent.py --root <ROOT> --shape_latent_name <SHAPE_LATENT_NAME> [--rank <RANK> --world_size <WORLD_SIZE>] [--resolution <SS_RESOLUTION>]
```

**Arguments:**
- `RESOLUTION`: Input O-Voxel resolution. Default is `1024`.
- `SS_RESOLUTION`: Resolution for sparse structures. Default is `64`.
- `SHAPE_LATENT_NAME`: The specific version name of the shape latent.

**Example:**
```bash
python data_toolkit/encode_shape_latent.py --root datasets/ObjaverseXL_sketchfab --resolution 512
python data_toolkit/encode_pbr_latent.py --root datasets/ObjaverseXL_sketchfab --resolution 512
python data_toolkit/encode_shape_latent.py --root datasets/ObjaverseXL_sketchfab --resolution 1024
python data_toolkit/encode_pbr_latent.py --root datasets/ObjaverseXL_sketchfab --resolution 1024

# Update metadata
python data_toolkit/build_metadata.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab

# Encode SS Latents
python data_toolkit/encode_ss_latent.py --root datasets/ObjaverseXL_sketchfab --shape_latent_name shape_enc_next_dc_f16c32_fp16_1024 --resolution 64

# Final metadata update
python data_toolkit/build_metadata.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab
```

### Step 7: Render Image Conditions

Render multi-view images to train the image-conditioned generator.
*Note: This process may utilize the CPU.*

```bash
python data_toolkit/render_cond.py <SUBSET> --root <ROOT> [--num_views <NUM_VIEWS>] [--rank <RANK> --world_size <WORLD_SIZE>]
```

**Arguments:**
- `NUM_VIEWS`: Number of views to render per asset. Default is `16`.

**Example:**
```bash
python data_toolkit/render_cond.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab
```

**Final Metadata Update:**
```bash
python data_toolkit/build_metadata.py ObjaverseXL --root datasets/ObjaverseXL_sketchfab
```
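All toolkit scripts accept `--rank`/`--world_size` to shard work across processes or nodes. Each rank owns a contiguous slice of the worklist; the sketch below mirrors the start/end computation used in `data_toolkit/asset_stats.py` (the other scripts are assumed to shard the same way):

```python
def shard(items, rank, world_size):
    # Contiguous sharding, matching the slicing in data_toolkit/asset_stats.py:
    #   start = len * rank // world_size, end = len * (rank + 1) // world_size
    start = len(items) * rank // world_size
    end = len(items) * (rank + 1) // world_size
    return items[start:end]

# Every item lands on exactly one rank, with near-equal shard sizes.
parts = [shard(list(range(10)), r, 4) for r in range(4)]
print(parts)  # [[0, 1], [2, 3, 4], [5, 6], [7, 8, 9]]
```

Because the shards are disjoint and cover the whole list, independent ranks can write their own `part_<rank>.csv` outputs without coordination.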
132
data_toolkit/asset_stats.py
Normal file
@@ -0,0 +1,132 @@
import os
import argparse
import pickle
from tqdm import tqdm
import pandas as pd
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--mesh_dump_root', type=str, default=None,
                        help='Directory to save the mesh dumps')
    parser.add_argument('--pbr_dump_root', type=str, default=None,
                        help='Directory to save the pbr dumps')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=0)
    opt = parser.parse_args()
    opt = edict(vars(opt))
    opt.mesh_dump_root = opt.mesh_dump_root or opt.root
    opt.pbr_dump_root = opt.pbr_dump_root or opt.root

    os.makedirs(os.path.join(opt.root, 'asset_stats', 'new_records'), exist_ok=True)

    # get file list
    if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
    if os.path.exists(os.path.join(opt.root, 'asset_stats', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'asset_stats', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')).set_index('sha256'))
    metadata = metadata.reset_index()
    if opt.instances is None:
        if 'num_faces' in metadata.columns:
            metadata = metadata[metadata['num_faces'].isnull()]
        metadata = metadata[(metadata['mesh_dumped'] == True) | (metadata['pbr_dumped'] == True)]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    records = []
    with ThreadPoolExecutor(max_workers=opt.max_workers or os.cpu_count()) as executor, \
         tqdm(total=len(metadata), desc='Processing objects') as pbar:
        def worker(metadatum):
            try:
                sha256 = metadatum['sha256']
                if metadatum['pbr_dumped'] == True:
                    with open(os.path.join(opt.pbr_dump_root, 'pbr_dumps', f'{sha256}.pickle'), 'rb') as f:
                        dump = pickle.load(f)

                    num_faces = 0
                    num_vertices = 0
                    for obj in dump['objects']:
                        if obj['vertices'].size == 0 or obj['faces'].size == 0:
                            continue
                        num_faces += obj['faces'].shape[0]
                        num_vertices += obj['vertices'].shape[0]

                    num_basecolor_tex = 0
                    num_metallic_tex = 0
                    num_roughness_tex = 0
                    num_alpha_tex = 0
                    for mat in dump['materials']:
                        if mat['baseColorTexture'] is not None:
                            num_basecolor_tex += 1
                        if mat['metallicTexture'] is not None:
                            num_metallic_tex += 1
                        if mat['roughnessTexture'] is not None:
                            num_roughness_tex += 1
                        if mat['alphaTexture'] is not None:
                            num_alpha_tex += 1

                    record = {
                        'sha256': sha256,
                        'num_faces': num_faces,
                        'num_vertices': num_vertices,
                        'num_basecolor_tex': num_basecolor_tex,
                        'num_metallic_tex': num_metallic_tex,
                        'num_roughness_tex': num_roughness_tex,
                        'num_alpha_tex': num_alpha_tex,
                    }
                    records.append(record)
                else:
                    with open(os.path.join(opt.mesh_dump_root, 'mesh_dumps', f'{sha256}.pickle'), 'rb') as f:
                        dump = pickle.load(f)

                    num_faces = 0
                    num_vertices = 0
                    for obj in dump['objects']:
                        if obj['vertices'].size == 0 or obj['faces'].size == 0:
                            continue
                        num_faces += obj['faces'].shape[0]
                        num_vertices += obj['vertices'].shape[0]

                    record = {
                        'sha256': sha256,
                        'num_faces': num_faces,
                        'num_vertices': num_vertices,
                    }
                    records.append(record)
                pbar.update()
            except Exception as e:
                print(f'Error processing {sha256}: {e}')
                pbar.update()

        for metadatum in metadata.to_dict('records'):
            executor.submit(worker, metadatum)

        executor.shutdown(wait=True)

    # save records
    records = pd.DataFrame.from_records(records)
    records.to_csv(os.path.join(opt.root, 'asset_stats', 'new_records', f'part_{opt.rank}.csv'), index=False)
242
data_toolkit/blender_script/dump_mesh.py
Executable file
@@ -0,0 +1,242 @@
import argparse, sys, os, math, io
from typing import *
import bpy
import bmesh
from mathutils import Vector, Matrix
import numpy as np
import pickle


"""=============== BLENDER ==============="""

IMPORT_FUNCTIONS: Dict[str, Callable] = {
    "obj": bpy.ops.import_scene.obj if bpy.app.version[0] < 4 else bpy.ops.wm.obj_import,
    "glb": bpy.ops.import_scene.gltf,
    "gltf": bpy.ops.import_scene.gltf,
    "usd": bpy.ops.import_scene.usd,
    "fbx": bpy.ops.import_scene.fbx,
    "stl": bpy.ops.import_mesh.stl if bpy.app.version[0] < 4 else bpy.ops.wm.stl_import,
    "usda": bpy.ops.import_scene.usda,
    "dae": bpy.ops.wm.collada_import,
    "ply": bpy.ops.import_mesh.ply if bpy.app.version[0] < 4 else bpy.ops.wm.ply_import,
    "abc": bpy.ops.wm.alembic_import,
    "blend": bpy.ops.wm.append,
}


def init_scene() -> None:
    """Resets the scene to a clean state.

    Returns:
        None
    """
    # delete everything
    for obj in bpy.data.objects:
        bpy.data.objects.remove(obj, do_unlink=True)

    # delete all the materials
    for material in bpy.data.materials:
        bpy.data.materials.remove(material, do_unlink=True)

    # delete all the textures
    for texture in bpy.data.textures:
        bpy.data.textures.remove(texture, do_unlink=True)

    # delete all the images
    for image in bpy.data.images:
        bpy.data.images.remove(image, do_unlink=True)


def load_object(object_path: str) -> None:
    """Loads a model with a supported file extension into the scene.

    Args:
        object_path (str): Path to the model file.

    Raises:
        ValueError: If the file extension is not supported.

    Returns:
        None
    """
    file_extension = object_path.split(".")[-1].lower()
    if file_extension is None:
        raise ValueError(f"Unsupported file type: {object_path}")

    if file_extension == "usdz":
        # install usdz io package
        dirname = os.path.dirname(os.path.realpath(__file__))
        usdz_package = os.path.join(dirname, "io_scene_usdz.zip")
        bpy.ops.preferences.addon_install(filepath=usdz_package)
        # enable it
        addon_name = "io_scene_usdz"
        bpy.ops.preferences.addon_enable(module=addon_name)
        # import the usdz
        from io_scene_usdz.import_usdz import import_usdz

        import_usdz(context, filepath=object_path, materials=True, animations=True)
        return None

    # load from existing import functions
    import_function = IMPORT_FUNCTIONS[file_extension]

    print(f"Loading object from {object_path}")
    if file_extension == "blend":
        import_function(directory=object_path, link=False)
    elif file_extension in {"glb", "gltf"}:
        import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS', bone_heuristic='TEMPERANCE')
    else:
        import_function(filepath=object_path)


def delete_invisible_objects() -> None:
    """Deletes all invisible objects in the scene.

    Returns:
        None
    """
    to_remove = []
    for obj in bpy.context.scene.objects:
        if obj.hide_viewport or obj.hide_render:
            obj.hide_viewport = False
            obj.hide_render = False
            obj.hide_select = False
            to_remove.append(obj)
    for obj in to_remove:
        bpy.data.objects.remove(obj, do_unlink=True)

    # Delete invisible collections
    invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
    for col in invisible_collections:
        bpy.data.collections.remove(col)


def scene_bbox() -> Tuple[Vector, Vector]:
    """Returns the bounding box of the scene.

    Taken from Shap-E rendering script
    (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)

    Returns:
        Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
    """
    bbox_min = (math.inf,) * 3
    bbox_max = (-math.inf,) * 3
    found = False
    scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)]
    for obj in scene_meshes:
        found = True
        for coord in obj.bound_box:
            coord = Vector(coord)
            coord = obj.matrix_world @ coord
            bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
            bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
    if not found:
        raise RuntimeError("no objects in scene to compute bounding box for")
    return Vector(bbox_min), Vector(bbox_max)


def normalize_scene() -> Tuple[float, Vector]:
    """Normalizes the scene by scaling and translating it to fit in a unit cube centered
    at the origin.

    Mostly taken from the Point-E / Shap-E rendering script
    (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
    but fix for multiple root objects: (see bug report here:
    https://github.com/openai/shap-e/pull/60).

    Returns:
        Tuple[float, Vector]: The scale factor and the offset applied to the scene.
    """
    scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent]
    if len(scene_root_objects) > 1:
        # create an empty object to be used as a parent for all root objects
        scene = bpy.data.objects.new("ParentEmpty", None)
        bpy.context.scene.collection.objects.link(scene)

        # parent all root objects to the empty object
        for obj in scene_root_objects:
            obj.parent = scene
    else:
        scene = scene_root_objects[0]

    bbox_min, bbox_max = scene_bbox()
    scale = 1 / max(bbox_max - bbox_min)
    scene.scale = scene.scale * scale

    # Apply scale to matrix_world.
    bpy.context.view_layer.update()
    bbox_min, bbox_max = scene_bbox()
    offset = -(bbox_min + bbox_max) / 2
    scene.matrix_world.translation += offset

    return scale, offset


def main(arg):
    # Initialize context
    if arg.object.endswith(".blend"):
        delete_invisible_objects()
    else:
        init_scene()
        load_object(arg.object)
    print('[INFO] Scene initialized.')

    # Normalize scene
    scale, offset = normalize_scene()
    print('[INFO] Scene normalized.')

    # Start dumping
    depsgraph = bpy.context.evaluated_depsgraph_get()
    scene = bpy.context.scene
    output = {
        'objects': [],
    }

    # Dumping meshes
    for obj in scene.objects:
        if obj.type != 'MESH':
            continue

        pack = {
            "vertices": None,
            "faces": None,
        }

        eval_obj = obj.evaluated_get(depsgraph)
        eval_mesh = eval_obj.to_mesh()

        bm = bmesh.new()
        bm.from_mesh(eval_mesh)
        bm.transform(obj.matrix_world)
        bmesh.ops.triangulate(bm, faces=bm.faces)
        bm.to_mesh(eval_mesh)
        bm.free()

        pack["vertices"] = np.array([
            v.co[:] for v in eval_mesh.vertices
        ], dtype=np.float32)  # (N, 3)

        pack["faces"] = np.array([
            [eval_mesh.loops[i].vertex_index for i in poly.loop_indices]
            for poly in eval_mesh.polygons
        ], dtype=np.int32)  # (F, 3)

        output['objects'].append(pack)

    # Save output
    os.makedirs(os.path.dirname(arg.output_path), exist_ok=True)
    with open(arg.output_path, 'wb') as f:
        pickle.dump(output, f)
    print('[INFO] Output saved to {}.'.format(arg.output_path))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Renders given obj file by rotation a camera around it.')
    parser.add_argument('--object', type=str, help='Path to the 3D model file to be rendered.')
    parser.add_argument('--output_path', type=str, default='/tmp', help='The path the output will be dumped to.')
    argv = sys.argv[sys.argv.index("--") + 1:]
    args = parser.parse_args(argv)

    main(args)
485
data_toolkit/blender_script/dump_pbr.py
Executable file
@@ -0,0 +1,485 @@
|
||||
import argparse, sys, os, math, io
|
||||
from typing import *
|
||||
import bpy
|
||||
import bmesh
|
||||
from mathutils import Vector, Matrix
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import pickle
|
||||
|
||||
|
||||
"""=============== BLENDER ==============="""
|
||||
|
||||
IMPORT_FUNCTIONS: Dict[str, Callable] = {
|
||||
"obj": bpy.ops.import_scene.obj if bpy.app.version[0] < 4 else bpy.ops.wm.obj_import,
|
||||
"glb": bpy.ops.import_scene.gltf,
|
||||
"gltf": bpy.ops.import_scene.gltf,
|
||||
"usd": bpy.ops.import_scene.usd,
|
||||
"fbx": bpy.ops.import_scene.fbx,
|
||||
"stl": bpy.ops.import_mesh.stl if bpy.app.version[0] < 4 else bpy.ops.wm.stl_import,
|
||||
"usda": bpy.ops.import_scene.usda,
|
||||
"dae": bpy.ops.wm.collada_import,
|
||||
"ply": bpy.ops.import_mesh.ply if bpy.app.version[0] < 4 else bpy.ops.wm.ply_import,
|
||||
"abc": bpy.ops.wm.alembic_import,
|
||||
"blend": bpy.ops.wm.append,
|
||||
}
|
||||
|
||||
|
||||
def init_scene() -> None:
|
||||
"""Resets the scene to a clean state.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# delete everything
|
||||
for obj in bpy.data.objects:
|
||||
bpy.data.objects.remove(obj, do_unlink=True)
|
||||
|
||||
# delete all the materials
|
||||
for material in bpy.data.materials:
|
||||
bpy.data.materials.remove(material, do_unlink=True)
|
||||
|
||||
# delete all the textures
|
||||
for texture in bpy.data.textures:
|
||||
bpy.data.textures.remove(texture, do_unlink=True)
|
||||
|
||||
# delete all the images
|
||||
for image in bpy.data.images:
|
||||
bpy.data.images.remove(image, do_unlink=True)
|
||||
|
||||
|
||||
def load_object(object_path: str) -> None:
|
||||
"""Loads a model with a supported file extension into the scene.
|
||||
|
||||
Args:
|
||||
object_path (str): Path to the model file.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not supported.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
file_extension = object_path.split(".")[-1].lower()
|
||||
if file_extension is None:
|
||||
raise ValueError(f"Unsupported file type: {object_path}")
|
||||
|
||||
if file_extension == "usdz":
|
||||
# install usdz io package
|
||||
dirname = os.path.dirname(os.path.realpath(__file__))
|
||||
usdz_package = os.path.join(dirname, "io_scene_usdz.zip")
|
||||
bpy.ops.preferences.addon_install(filepath=usdz_package)
|
||||
# enable it
|
||||
addon_name = "io_scene_usdz"
|
||||
bpy.ops.preferences.addon_enable(module=addon_name)
|
||||
# import the usdz
|
||||
from io_scene_usdz.import_usdz import import_usdz
|
||||
|
||||
import_usdz(context, filepath=object_path, materials=True, animations=True)
|
||||
return None
|
||||
|
||||
# load from existing import functions
|
||||
import_function = IMPORT_FUNCTIONS[file_extension]
|
||||
|
||||
print(f"Loading object from {object_path}")
|
||||
if file_extension == "blend":
|
||||
import_function(directory=object_path, link=False)
|
||||
elif file_extension in {"glb", "gltf"}:
|
||||
import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS', bone_heuristic='TEMPERANCE')
|
||||
else:
|
||||
import_function(filepath=object_path)
|
||||
|
||||
|
||||
def delete_invisible_objects() -> None:
|
||||
"""Deletes all invisible objects in the scene.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
to_remove = []
|
||||
for obj in bpy.context.scene.objects:
|
||||
if obj.hide_viewport or obj.hide_render:
|
||||
obj.hide_viewport = False
|
||||
obj.hide_render = False
|
||||
obj.hide_select = False
|
||||
to_remove.append(obj)
|
||||
for obj in to_remove:
|
||||
bpy.data.objects.remove(obj, do_unlink=True)
|
||||
|
||||
# Delete invisible collections
|
||||
invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
|
||||
for col in invisible_collections:
|
||||
bpy.data.collections.remove(col)
|
||||
|
||||
|
||||
def scene_bbox() -> Tuple[Vector, Vector]:
|
||||
"""Returns the bounding box of the scene.
|
||||
|
||||
Taken from Shap-E rendering script
|
||||
(https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)
|
||||
|
||||
Returns:
|
||||
Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
|
||||
"""
|
||||
bbox_min = (math.inf,) * 3
|
||||
bbox_max = (-math.inf,) * 3
|
||||
found = False
|
||||
scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)]
|
||||
for obj in scene_meshes:
|
||||
found = True
|
||||
for coord in obj.bound_box:
|
||||
coord = Vector(coord)
|
||||
coord = obj.matrix_world @ coord
|
||||
bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
|
||||
bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
|
||||
if not found:
|
||||
raise RuntimeError("no objects in scene to compute bounding box for")
|
||||
return Vector(bbox_min), Vector(bbox_max)
|
||||
|
||||
|
||||
def normalize_scene() -> Tuple[float, Vector]:
|
||||
"""Normalizes the scene by scaling and translating it to fit in a unit cube centered
|
||||
at the origin.
|
||||
|
||||
Mostly taken from the Point-E / Shap-E rendering script
|
||||
(https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
|
||||
but fix for multiple root objects: (see bug report here:
|
||||
https://github.com/openai/shap-e/pull/60).
|
||||
|
||||
Returns:
|
||||
Tuple[float, Vector]: The scale factor and the offset applied to the scene.
|
||||
"""
|
||||
scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent]
|
||||
if len(scene_root_objects) > 1:
|
||||
# create an empty object to be used as a parent for all root objects
|
||||
scene = bpy.data.objects.new("ParentEmpty", None)
|
||||
bpy.context.scene.collection.objects.link(scene)
|
||||
|
||||
# parent all root objects to the empty object
|
||||
for obj in scene_root_objects:
|
||||
obj.parent = scene
|
||||
else:
|
||||
scene = scene_root_objects[0]
|
||||
|
||||
bbox_min, bbox_max = scene_bbox()
|
||||
scale = 1 / max(bbox_max - bbox_min)
|
||||
scene.scale = scene.scale * scale
|
||||
|
||||
# Apply scale to matrix_world.
|
||||
bpy.context.view_layer.update()
|
||||
bbox_min, bbox_max = scene_bbox()
|
||||
offset = -(bbox_min + bbox_max) / 2
|
||||
scene.matrix_world.translation += offset
|
||||
|
||||
return scale, offset
|
||||
|
||||
|
||||
# =============== NODE TREE PARSING ===============


def extract_image(tex_node, channels):
    image = tex_node.image
    pixels = np.array(image.pixels[:])
    data = pixels.reshape(image.size[1], image.size[0], -1)
    data = data[..., channels]

    if data.dtype != np.uint8:
        data = np.clip(data, 0.0, 1.0)
        data = (data * 255).astype(np.uint8)

    if len(data.shape) == 2:  # Single channel
        pil_image = Image.fromarray(data, mode='L')
    elif data.shape[2] == 3:
        pil_image = Image.fromarray(data, mode='RGB')
    elif data.shape[2] == 4:
        pil_image = Image.fromarray(data, mode='RGBA')
    else:
        raise ValueError("Unsupported channel shape for image")

    buffer = io.BytesIO()
    pil_image.save(buffer, format='PNG')
    png_bytes = buffer.getvalue()

    return {
        'image': png_bytes,
        'interpolation': tex_node.interpolation,
        'extension': tex_node.extension,
    }


def try_extract_image(link, expected_channel='RGB'):
    """
    Tries to extract an image from a texture node link.
    Supported sub-tree patterns:
    - RGB:
        TEX_IMAGE ->
    - R, G, B:
        TEX_IMAGE -> SEPARATE_COLOR ->
    - A:
        TEX_IMAGE ->
    """
    assert expected_channel in ['RGB', 'R', 'G', 'B', 'A'], "Unsupported channel"

    if expected_channel == 'RGB':
        assert link.from_node.type == 'TEX_IMAGE', "Material is not supported"
        assert link.from_socket.name == 'Color', "Material is not supported"
        tex_node = link.from_node
        return extract_image(tex_node, [0, 1, 2])

    if expected_channel in ['R', 'G', 'B']:
        socket_name = {
            'R': 'Red',
            'G': 'Green',
            'B': 'Blue',
        }[expected_channel]
        assert link.from_node.type == 'SEPARATE_COLOR' and link.from_node.mode == 'RGB', \
            f"Material is not supported, {link.from_node.type}, {link.from_node.mode}"
        assert link.from_socket.name == socket_name, "Material is not supported"
        sep_node = link.from_node
        assert sep_node.inputs[0].is_linked and sep_node.inputs[0].links[0].from_node.type == 'TEX_IMAGE', \
            "Material is not supported"
        assert sep_node.inputs[0].links[0].from_socket.name == 'Color', "Material is not supported"
        tex_node = sep_node.inputs[0].links[0].from_node
        channel_index = {
            'R': 0,
            'G': 1,
            'B': 2,
        }[expected_channel]
        return extract_image(tex_node, channel_index)

    if expected_channel == 'A':
        assert link.from_node.type == 'TEX_IMAGE', "Material is not supported"
        assert link.from_socket.name == 'Alpha', "Material is not supported"
        tex_node = link.from_node
        return extract_image(tex_node, 3)
def try_extract_factor(link, mode='color'):
    """
    Tries to extract a factor from a math node link.
    Supported sub-tree patterns:
    - color:
        ANY -> MIX(MULTIPLY) ->
    - scalar:
        ANY -> MATH(MULTIPLY) ->
    """
    assert mode in ['color', 'scalar'], "Unsupported mode"

    if mode == 'color':
        if link.from_node.type == 'MIX':
            mix_node = link.from_node
            assert mix_node.data_type == 'RGBA' and mix_node.blend_type == 'MULTIPLY', \
                f"Material is not supported, {mix_node.data_type}, {mix_node.blend_type}"
            assert not mix_node.inputs['Factor'].is_linked and mix_node.inputs['Factor'].default_value == 1.0, \
                "Material is not supported"
            if mix_node.inputs['A'].is_linked:
                assert not mix_node.inputs['B'].is_linked, "Material is not supported"
                return (list(mix_node.inputs['B'].default_value)[:3], mix_node.inputs['A'].links[0])
            else:
                assert not mix_node.inputs['A'].is_linked, "Material is not supported"
                assert mix_node.inputs['B'].is_linked, "Material is not supported"
                return (list(mix_node.inputs['A'].default_value)[:3], mix_node.inputs['B'].links[0])
        return ([1.0, 1.0, 1.0], link)

    if mode == 'scalar':
        if link.from_node.type == 'MATH':
            math_node = link.from_node
            assert math_node.operation == 'MULTIPLY', "Material is not supported"
            assert math_node.inputs[0].is_linked, "Material is not supported"
            assert not math_node.inputs[1].is_linked, "Material is not supported"
            return (math_node.inputs[1].default_value, math_node.inputs[0].links[0])
        return (1.0, link)


def try_extract_image_with_factor(link, expected_channel='RGB'):
    """
    Tries to extract an image and a factor from a texture node link.
    """
    factor, link = try_extract_factor(link, 'color' if expected_channel == 'RGB' else 'scalar')
    image = try_extract_image(link, expected_channel)
    return (factor, image)
def main(arg):
    # Initialize context
    if arg.object.endswith(".blend"):
        delete_invisible_objects()
    else:
        init_scene()
        load_object(arg.object)
    print('[INFO] Scene initialized.')

    # Normalize scene
    scale, offset = normalize_scene()
    print('[INFO] Scene normalized.')

    # Start dumping
    depsgraph = bpy.context.evaluated_depsgraph_get()
    scene = bpy.context.scene
    output = {
        'materials': [],
        'objects': [],
    }

    # Dump materials
    for mat in bpy.data.materials:
        assert mat.use_nodes, "Material is not supported"

        pack = {
            "baseColorFactor": [1.0, 1.0, 1.0],
            "alphaFactor": 1.0,
            "metallicFactor": 1.0,
            "roughnessFactor": 1.0,
            "alphaMode": "OPAQUE",
            "alphaCutoff": 0.5,
            "baseColorTexture": None,
            "alphaTexture": None,
            "metallicTexture": None,
            "roughnessTexture": None,
        }

        try:
            principled_node = mat.node_tree.nodes.get('Principled BSDF')
            assert principled_node is not None, "Material is not supported"

            # Handle base color
            if not principled_node.inputs['Base Color'].is_linked:
                pack["baseColorFactor"] = list(principled_node.inputs['Base Color'].default_value)
            else:
                link = principled_node.inputs['Base Color'].links[0]
                if link.from_node.type == 'RGB':
                    pack["baseColorFactor"] = list(link.from_node.outputs[0].default_value)
                else:
                    factor, image = try_extract_image_with_factor(link, 'RGB')
                    pack["baseColorFactor"] = factor
                    pack["baseColorTexture"] = image

            # Handle alpha
            if not principled_node.inputs['Alpha'].is_linked:
                pack["alphaFactor"] = principled_node.inputs['Alpha'].default_value
                if pack["alphaFactor"] < 1.0:
                    pack["alphaMode"] = "BLEND"
            else:
                link = principled_node.inputs['Alpha'].links[0]
                node = link.from_node
                if node.type == 'VALUE':
                    pack["alphaFactor"] = node.outputs[0].default_value
                    if pack["alphaFactor"] < 1.0:
                        pack["alphaMode"] = "BLEND"
                else:
                    pack["alphaMode"] = "BLEND"
                    if node.type == 'MATH':
                        if node.operation == 'ROUND':
                            assert node.inputs[0].is_linked, "Material is not supported"
                            pack["alphaMode"] = "MASK"
                            link = node.inputs[0].links[0]
                        elif node.operation == 'SUBTRACT':
                            assert node.inputs[0].default_value == 1.0 and \
                                node.inputs[1].is_linked and \
                                node.inputs[1].links[0].from_node.type == 'MATH' and \
                                node.inputs[1].links[0].from_node.operation == 'LESS_THAN', \
                                "Material is not supported"
                            assert node.inputs[1].links[0].from_node.inputs[0].is_linked, "Material is not supported"
                            pack["alphaMode"] = "MASK"
                            pack["alphaCutoff"] = node.inputs[1].links[0].from_node.inputs[1].default_value
                            link = node.inputs[1].links[0].from_node.inputs[0].links[0]
                    factor, image = try_extract_image_with_factor(link, 'A')
                    pack["alphaFactor"] = factor
                    pack["alphaTexture"] = image

            # Handle metallic
            if not principled_node.inputs['Metallic'].is_linked:
                pack["metallicFactor"] = principled_node.inputs['Metallic'].default_value
            else:
                link = principled_node.inputs['Metallic'].links[0]
                node = link.from_node
                if node.type == 'VALUE':
                    pack["metallicFactor"] = node.outputs[0].default_value
                else:
                    factor, image = try_extract_image_with_factor(link, 'B')
                    pack["metallicFactor"] = factor
                    pack["metallicTexture"] = image

            # Handle roughness
            if not principled_node.inputs['Roughness'].is_linked:
                pack["roughnessFactor"] = principled_node.inputs['Roughness'].default_value
            else:
                link = principled_node.inputs['Roughness'].links[0]
                node = link.from_node
                if node.type == 'VALUE':
                    pack["roughnessFactor"] = node.outputs[0].default_value
                else:
                    factor, image = try_extract_image_with_factor(link, 'G')
                    pack["roughnessFactor"] = factor
                    pack["roughnessTexture"] = image

            output['materials'].append(pack)
        except Exception:
            # Dump the node names of the unsupported material for debugging.
            with open(arg.output_path + '_error.txt', 'w') as f:
                f.write(str([[n.name] for n in mat.node_tree.nodes]))
            raise RuntimeError("Material is not supported")
    # Dump meshes
    for obj in scene.objects:
        if obj.type != 'MESH':
            continue

        pack = {
            "vertices": None,
            "faces": None,
            "normals": None,
            "uvs": None,
            "mat_ids": None,
        }

        eval_obj = obj.evaluated_get(depsgraph)
        eval_mesh = eval_obj.to_mesh()

        bm = bmesh.new()
        bm.from_mesh(eval_mesh)
        bm.transform(obj.matrix_world)
        bmesh.ops.triangulate(bm, faces=bm.faces)
        bm.to_mesh(eval_mesh)
        bm.free()

        pack["vertices"] = np.array([
            v.co[:] for v in eval_mesh.vertices
        ], dtype=np.float32)  # (N, 3)

        pack["faces"] = np.array([
            [eval_mesh.loops[i].vertex_index for i in poly.loop_indices]
            for poly in eval_mesh.polygons
        ], dtype=np.int32)  # (F, 3)

        pack["normals"] = np.array([
            [eval_mesh.loops[i].normal for i in poly.loop_indices]
            for poly in eval_mesh.polygons
        ], dtype=np.float32)  # (F, 3, 3)

        if eval_mesh.uv_layers.active is not None:
            pack["uvs"] = np.array([
                [eval_mesh.uv_layers.active.data[i].uv for i in poly.loop_indices]
                for poly in eval_mesh.polygons
            ], dtype=np.float32)  # (F, 3, 2)

        pack["mat_ids"] = np.array([
            bpy.data.materials.find(obj.material_slots[poly.material_index].name)
            if len(obj.material_slots) > 0 and obj.material_slots[poly.material_index].material is not None else -1
            for poly in eval_mesh.polygons
        ], dtype=np.int32)

        output['objects'].append(pack)

    # Save output
    os.makedirs(os.path.dirname(arg.output_path), exist_ok=True)
    with open(arg.output_path, 'wb') as f:
        pickle.dump(output, f)
    print('[INFO] Output saved to {}.'.format(arg.output_path))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Dumps the meshes and PBR materials of a given 3D model file.')
    parser.add_argument('--object', type=str, help='Path to the 3D model file to be dumped.')
    parser.add_argument('--output_path', type=str, default='/tmp', help='The path the output will be dumped to.')
    argv = sys.argv[sys.argv.index("--") + 1:]
    args = parser.parse_args(argv)

    main(args)
data_toolkit/blender_script/install_pillow.py (new executable file)
@@ -0,0 +1,6 @@
import subprocess
import sys
import ensurepip

# Ensure pip is available in Blender's bundled Python, then install Pillow.
ensurepip.bootstrap()
subprocess.check_call([sys.executable, "-m", "pip", "install", "Pillow"])
data_toolkit/blender_script/io_scene_usdz.zip (new executable file; binary file not shown)

data_toolkit/blender_script/render_cond.py (new file)
@@ -0,0 +1,437 @@
import argparse, sys, os, math, re, glob
from typing import *
import bpy
from mathutils import Vector, Matrix
import numpy as np
import json


"""=============== BLENDER ==============="""

IMPORT_FUNCTIONS: Dict[str, Callable] = {
    "obj": bpy.ops.import_scene.obj,
    "glb": bpy.ops.import_scene.gltf,
    "gltf": bpy.ops.import_scene.gltf,
    "usd": bpy.ops.import_scene.usd,
    "fbx": bpy.ops.import_scene.fbx,
    "stl": bpy.ops.import_mesh.stl,
    "usda": bpy.ops.import_scene.usda,
    "dae": bpy.ops.wm.collada_import,
    "ply": bpy.ops.import_mesh.ply,
    "abc": bpy.ops.wm.alembic_import,
    "blend": bpy.ops.wm.append,
}

EXT = {
    'PNG': 'png',
    'JPEG': 'jpg',
    'OPEN_EXR': 'exr',
    'TIFF': 'tiff',
    'BMP': 'bmp',
    'HDR': 'hdr',
    'TARGA': 'tga',
}


def init_render(engine='CYCLES', resolution=512):
    bpy.context.scene.render.engine = engine
    bpy.context.scene.render.resolution_x = resolution
    bpy.context.scene.render.resolution_y = resolution
    bpy.context.scene.render.resolution_percentage = 100
    bpy.context.scene.render.image_settings.file_format = 'PNG'
    bpy.context.scene.render.image_settings.color_mode = 'RGBA'
    bpy.context.scene.render.film_transparent = True

    bpy.context.scene.cycles.device = 'GPU'
    bpy.context.scene.cycles.samples = 32
    bpy.context.scene.cycles.filter_type = 'BOX'
    bpy.context.scene.cycles.filter_width = 1
    bpy.context.scene.cycles.diffuse_bounces = 1
    bpy.context.scene.cycles.glossy_bounces = 1
    bpy.context.scene.cycles.transparent_max_bounces = 3
    bpy.context.scene.cycles.transmission_bounces = 3
    bpy.context.scene.cycles.use_denoising = True

    bpy.context.preferences.addons['cycles'].preferences.get_devices()
    bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
def init_scene() -> None:
    """Resets the scene to a clean state.

    Returns:
        None
    """
    # delete everything
    for obj in bpy.data.objects:
        bpy.data.objects.remove(obj, do_unlink=True)

    # delete all the materials
    for material in bpy.data.materials:
        bpy.data.materials.remove(material, do_unlink=True)

    # delete all the textures
    for texture in bpy.data.textures:
        bpy.data.textures.remove(texture, do_unlink=True)

    # delete all the images
    for image in bpy.data.images:
        bpy.data.images.remove(image, do_unlink=True)


def init_camera():
    cam = bpy.data.objects.new('Camera', bpy.data.cameras.new('Camera'))
    bpy.context.collection.objects.link(cam)
    bpy.context.scene.camera = cam
    cam.data.sensor_height = cam.data.sensor_width = 32
    cam_constraint = cam.constraints.new(type='TRACK_TO')
    cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'
    cam_constraint.up_axis = 'UP_Y'
    cam_empty = bpy.data.objects.new("Empty", None)
    cam_empty.location = (0, 0, 0)
    bpy.context.scene.collection.objects.link(cam_empty)
    cam_constraint.target = cam_empty
    return cam


def init_uniform_lighting():
    # Clear existing lights
    bpy.ops.object.select_all(action="DESELECT")
    bpy.ops.object.select_by_type(type="LIGHT")
    bpy.ops.object.delete()

    # Create environment light
    if bpy.context.scene.world is None:
        world = bpy.data.worlds.new("World")
        bpy.context.scene.world = world
    else:
        world = bpy.context.scene.world

    # Enable nodes
    world.use_nodes = True
    node_tree = world.node_tree
    nodes = node_tree.nodes
    links = node_tree.links

    # Remove default nodes
    for node in nodes:
        nodes.remove(node)

    # Create background node
    bg_node = nodes.new(type="ShaderNodeBackground")
    bg_node.inputs["Color"].default_value = (1.0, 1.0, 1.0, 1.0)
    bg_node.inputs["Strength"].default_value = 1.0
    output_node = nodes.new(type="ShaderNodeOutputWorld")
    links.new(bg_node.outputs["Background"], output_node.inputs["Surface"])


def init_random_lighting(camera_dir: np.ndarray) -> None:
    # Clear existing lights
    bpy.ops.object.select_all(action="DESELECT")
    bpy.ops.object.select_by_type(type="LIGHT")
    bpy.ops.object.delete()

    # Create environment light
    if bpy.context.scene.world is None:
        world = bpy.data.worlds.new("World")
        bpy.context.scene.world = world
    else:
        world = bpy.context.scene.world

    # Enable nodes
    world.use_nodes = True
    node_tree = world.node_tree
    nodes = node_tree.nodes
    links = node_tree.links

    # Remove default nodes
    for node in nodes:
        nodes.remove(node)

    # Randomly place point lights
    num_lights = np.random.randint(1, 4)
    total_strength = 1.5
    for i in range(num_lights):
        new_light = bpy.data.objects.new(f"Light_{i}", bpy.data.lights.new(f"Light_{i}", type="POINT"))
        bpy.context.collection.objects.link(new_light)

        new_light_distance = 1 / np.random.uniform(1/100, 1/10)
        new_light_dir = np.random.randn(3)
        new_light_dir[2] += 0.6
        new_light_dir = new_light_dir / np.linalg.norm(new_light_dir)
        new_light_location = new_light_dir * new_light_distance
        new_light_camera_strength_ratio = max(np.sum(camera_dir * new_light_dir) * 0.5 + 0.5, 0)
        new_light_max_energy = total_strength / (np.sum(camera_dir * new_light_dir) * 0.45 + 0.55)
        new_light_strength = np.sqrt(np.random.uniform(0.01, 1)) * new_light_max_energy
        new_light_camera_strength = new_light_camera_strength_ratio * new_light_strength
        total_strength -= new_light_camera_strength

        new_light.location = (new_light_location[0], new_light_location[1], new_light_location[2])
        new_light.data.color = (1.0, 1.0, 1.0)
        new_light.data.energy = new_light_strength * new_light_distance**2 * 31.4
        new_light.data.shadow_soft_size = np.random.uniform(0.1, 0.1 * new_light_distance)

    # Create background node with the remaining strength
    bg_node = nodes.new(type="ShaderNodeBackground")
    bg_node.inputs["Color"].default_value = (1.0, 1.0, 1.0, 1.0)
    bg_node.inputs["Strength"].default_value = total_strength
    output_node = nodes.new(type="ShaderNodeOutputWorld")
    links.new(bg_node.outputs["Background"], output_node.inputs["Surface"])
def load_object(object_path: str) -> None:
    """Loads a model with a supported file extension into the scene.

    Args:
        object_path (str): Path to the model file.

    Raises:
        ValueError: If the file extension is not supported.

    Returns:
        None
    """
    file_extension = object_path.split(".")[-1].lower()
    if file_extension != "usdz" and file_extension not in IMPORT_FUNCTIONS:
        raise ValueError(f"Unsupported file type: {object_path}")

    if file_extension == "usdz":
        # install the usdz io package
        dirname = os.path.dirname(os.path.realpath(__file__))
        usdz_package = os.path.join(dirname, "io_scene_usdz.zip")
        bpy.ops.preferences.addon_install(filepath=usdz_package)
        # enable it
        addon_name = "io_scene_usdz"
        bpy.ops.preferences.addon_enable(module=addon_name)
        # import the usdz
        from io_scene_usdz.import_usdz import import_usdz

        import_usdz(bpy.context, filepath=object_path, materials=True, animations=True)
        return None

    # load from existing import functions
    import_function = IMPORT_FUNCTIONS[file_extension]

    print(f"Loading object from {object_path}")
    if file_extension == "blend":
        import_function(directory=object_path, link=False)
    elif file_extension in {"glb", "gltf"}:
        import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS')
    else:
        import_function(filepath=object_path)
def delete_invisible_objects() -> None:
    """Deletes all invisible objects in the scene.

    Returns:
        None
    """
    # bpy.ops.object.mode_set(mode="OBJECT")
    bpy.ops.object.select_all(action="DESELECT")
    for obj in bpy.context.scene.objects:
        if obj.hide_viewport or obj.hide_render:
            obj.hide_viewport = False
            obj.hide_render = False
            obj.hide_select = False
            obj.select_set(True)
    bpy.ops.object.delete()

    # Delete invisible collections
    invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
    for col in invisible_collections:
        bpy.data.collections.remove(col)


def unhide_all_objects() -> None:
    """Unhides all objects in the scene.

    Returns:
        None
    """
    for obj in bpy.context.scene.objects:
        obj.hide_set(False)


def convert_to_meshes() -> None:
    """Converts all objects in the scene to meshes.

    Returns:
        None
    """
    bpy.ops.object.select_all(action="DESELECT")
    bpy.context.view_layer.objects.active = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"][0]
    for obj in bpy.context.scene.objects:
        obj.select_set(True)
    bpy.ops.object.convert(target="MESH")


def triangulate_meshes() -> None:
    """Triangulates all meshes in the scene.

    Returns:
        None
    """
    bpy.ops.object.select_all(action="DESELECT")
    objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
    bpy.context.view_layer.objects.active = objs[0]
    for obj in objs:
        obj.select_set(True)
    bpy.ops.object.mode_set(mode="EDIT")
    bpy.ops.mesh.reveal()
    bpy.ops.mesh.select_all(action="SELECT")
    bpy.ops.mesh.quads_convert_to_tris(quad_method="BEAUTY", ngon_method="BEAUTY")
    bpy.ops.object.mode_set(mode="OBJECT")
    bpy.ops.object.select_all(action="DESELECT")
def scene_bbox() -> Tuple[Vector, Vector]:
    """Returns the bounding box of the scene.

    Taken from the Shap-E rendering script
    (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)

    Returns:
        Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
    """
    bbox_min = (math.inf,) * 3
    bbox_max = (-math.inf,) * 3
    found = False
    scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)]
    for obj in scene_meshes:
        found = True
        for coord in obj.bound_box:
            coord = Vector(coord)
            coord = obj.matrix_world @ coord
            bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
            bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
    if not found:
        raise RuntimeError("no objects in scene to compute bounding box for")
    return Vector(bbox_min), Vector(bbox_max)


def normalize_scene() -> Tuple[float, Vector]:
    """Normalizes the scene by scaling and translating it to fit in a unit cube centered
    at the origin.

    Mostly taken from the Point-E / Shap-E rendering script
    (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
    but fixed to handle multiple root objects (see bug report:
    https://github.com/openai/shap-e/pull/60).

    Returns:
        Tuple[float, Vector]: The scale factor and the offset applied to the scene.
    """
    scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent]
    if len(scene_root_objects) > 1:
        # Create an empty object to serve as the parent of all root objects.
        scene = bpy.data.objects.new("ParentEmpty", None)
        bpy.context.scene.collection.objects.link(scene)

        # Parent all root objects to the empty object.
        for obj in scene_root_objects:
            obj.parent = scene
    else:
        scene = scene_root_objects[0]

    bbox_min, bbox_max = scene_bbox()
    scale = 1 / max(bbox_max - bbox_min)
    scene.scale = scene.scale * scale

    # Apply the scale to matrix_world, then recenter the scene at the origin.
    bpy.context.view_layer.update()
    bbox_min, bbox_max = scene_bbox()
    offset = -(bbox_min + bbox_max) / 2
    scene.matrix_world.translation += offset
    bpy.ops.object.select_all(action="DESELECT")

    return scale, offset
def get_transform_matrix(obj: bpy.types.Object) -> list:
    pos, rt, _ = obj.matrix_world.decompose()
    rt = rt.to_matrix()
    matrix = []
    for ii in range(3):
        a = []
        for jj in range(3):
            a.append(rt[ii][jj])
        a.append(pos[ii])
        matrix.append(a)
    matrix.append([0, 0, 0, 1])
    return matrix
def main(arg):
    if arg.object.endswith(".blend"):
        delete_invisible_objects()
    else:
        init_scene()
        load_object(arg.object)
    print('[INFO] Scene initialized.')

    # Normalize scene
    scale, offset = normalize_scene()
    print('[INFO] Scene normalized.')

    # Initialize camera and lighting
    cam = init_camera()
    init_uniform_lighting()
    print('[INFO] Camera and lighting initialized.')

    # ============= Render conditional views =============
    init_render(engine=arg.engine, resolution=arg.cond_resolution)
    # Create a list of views
    to_export = {
        "aabb": [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        "scale": scale,
        "offset": [offset.x, offset.y, offset.z],
        "frames": []
    }
    views = json.loads(arg.cond_views)
    for i, view in enumerate(views):
        cam_dir = np.array([
            np.cos(view['yaw']) * np.cos(view['pitch']),
            np.sin(view['yaw']) * np.cos(view['pitch']),
            np.sin(view['pitch'])
        ])
        init_random_lighting(cam_dir)
        cam.location = (
            view['radius'] * cam_dir[0],
            view['radius'] * cam_dir[1],
            view['radius'] * cam_dir[2]
        )
        cam.data.lens = 16 / np.tan(view['fov'] / 2)

        bpy.context.scene.render.filepath = os.path.join(arg.cond_output_folder, f'{i:03d}.png')

        # Render the scene
        bpy.ops.render.render(write_still=True)
        bpy.context.view_layer.update()

        # Save camera parameters
        metadata = {
            "file_path": f'{i:03d}.png',
            "camera_angle_x": view['fov'],
            "transform_matrix": get_transform_matrix(cam)
        }
        to_export["frames"].append(metadata)

    # Save the camera parameters
    with open(os.path.join(arg.cond_output_folder, 'transforms.json'), 'w') as f:
        json.dump(to_export, f, indent=4)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Renders the given object file by rotating a camera around it.')
    parser.add_argument('--object', type=str, help='Path to the 3D model file to be rendered.')
    parser.add_argument('--cond_views', type=str, help='JSON string of views. Contains a list of {yaw, pitch, radius, fov} objects.')
    parser.add_argument('--cond_output_folder', type=str, default='/tmp', help='The path the output will be dumped to.')
    parser.add_argument('--cond_resolution', type=int, default=1024, help='Resolution of the conditional images.')
    parser.add_argument('--engine', type=str, default='CYCLES', help='Blender internal engine for rendering, e.g. CYCLES, BLENDER_EEVEE, ...')
    argv = sys.argv[sys.argv.index("--") + 1:]
    args = parser.parse_args(argv)

    main(args)
data_toolkit/build_metadata.py (new executable file)
@@ -0,0 +1,219 @@
import os
import shutil
import sys
import time
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict


def update_metadata(path, opt):
    if not os.path.exists(path):
        return None
    timestamp = str(int(time.time()))
    os.makedirs(os.path.join(path, 'merged_records'), exist_ok=True)
    os.makedirs(os.path.join(path, 'new_records'), exist_ok=True)
    if opt.from_merged_records:
        df_files = [f for f in os.listdir(os.path.join(path, 'merged_records')) if f.endswith('.csv')]
        df_files = [f for f in df_files if int(f.split('_')[0]) >= opt.record_start]
    else:
        df_files = [f for f in os.listdir(os.path.join(path, 'new_records')) if f.startswith('part_') and f.endswith('.csv')]
    df_parts = []
    for f in df_files:
        try:
            df_parts.append(pd.read_csv(os.path.join(path, 'new_records', f)))
        except Exception as e:
            print(f"Failed to read {f}: {e}")
    if len(df_parts) > 0:
        if os.path.exists(os.path.join(path, 'metadata.csv')):
            metadata = pd.read_csv(os.path.join(path, 'metadata.csv'))
        else:
            columns = df_parts[0].columns
            metadata = pd.DataFrame(columns=columns)
        metadata.set_index('sha256', inplace=True)
        for df_part in df_parts:
            if 'sha256' in df_part.columns:
                df_part.set_index('sha256', inplace=True)
            metadata = df_part.combine_first(metadata)
        metadata.to_csv(os.path.join(path, 'metadata.csv'))
        for f in df_files:
            shutil.move(os.path.join(path, 'new_records', f), os.path.join(path, 'merged_records', f'{timestamp}_{f}'))
        return metadata
    else:
        if os.path.exists(os.path.join(path, 'metadata.csv')):
            return pd.read_csv(os.path.join(path, 'metadata.csv'))
        return None
if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--download_root', type=str, default=None,
                        help='Directory to save the downloaded files')
    parser.add_argument('--thumbnail_root', type=str, default=None,
                        help='Directory to save the thumbnail files')
    parser.add_argument('--render_cond_root', type=str, default=None,
                        help='Directory to save the render condition files')
    parser.add_argument('--mesh_dump_root', type=str, default=None,
                        help='Directory to save the mesh files')
    parser.add_argument('--pbr_dump_root', type=str, default=None,
                        help='Directory to save the PBR files')
    parser.add_argument('--dual_grid_root', type=str, default=None,
                        help='Directory to save the dual grid files')
    parser.add_argument('--pbr_voxel_root', type=str, default=None,
                        help='Directory to save the PBR voxel files')
    parser.add_argument('--ss_latent_root', type=str, default=None,
                        help='Directory to save the sparse structure latent files')
    parser.add_argument('--shape_latent_root', type=str, default=None,
                        help='Directory to save the shape latent files')
    parser.add_argument('--pbr_latent_root', type=str, default=None,
                        help='Directory to save the PBR latent files')
    parser.add_argument('--field', type=str, default='all',
                        help='Fields to process, separated by commas')
    parser.add_argument('--from_file', action='store_true',
                        help='Build metadata from files instead of from processing records. ' +
                             'Useful when some processing failed to generate records but the files already exist.')
    parser.add_argument('--from_merged_records', action='store_true',
                        help='Build metadata from merged records')
    parser.add_argument('--record_start', type=int)
    parser.add_argument('--rebuild', action='store_true',
                        help='Rebuild metadata from scratch, ignoring existing metadata.')
    dataset_utils.add_args(parser)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))
    opt.download_root = opt.download_root or opt.root
    opt.thumbnail_root = opt.thumbnail_root or opt.root
    opt.render_cond_root = opt.render_cond_root or opt.root
    opt.mesh_dump_root = opt.mesh_dump_root or opt.root
    opt.pbr_dump_root = opt.pbr_dump_root or opt.root
    opt.dual_grid_root = opt.dual_grid_root or opt.root
    opt.pbr_voxel_root = opt.pbr_voxel_root or opt.root
    opt.ss_latent_root = opt.ss_latent_root or opt.root
    opt.shape_latent_root = opt.shape_latent_root or opt.root
    opt.pbr_latent_root = opt.pbr_latent_root or opt.root

    os.makedirs(opt.root, exist_ok=True)

    opt.field = opt.field.split(',')
# get file list
|
||||
if os.path.exists(os.path.join(opt.root, 'metadata.csv')):
|
||||
print('Loading previous metadata...')
|
||||
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv'))
|
||||
else:
|
||||
metadata = dataset_utils.get_metadata(**opt)
|
||||
metadata.to_csv(os.path.join(opt.root, 'metadata.csv'), index=False)
|
||||
|
||||
# merge downloaded
|
||||
downloaded_metadata = update_metadata(os.path.join(opt.download_root, 'raw'), opt)
|
||||
|
||||
# merge thumbnails
|
||||
thumbnail_metadata = update_metadata(os.path.join(opt.thumbnail_root, 'thumbnails'), opt)
|
||||
|
||||
# merge aesthetic scores
|
||||
aesthetic_score_metadata = update_metadata(os.path.join(opt.root, 'aesthetic_scores'), opt)
|
||||
|
||||
# merge render conditions
|
||||
render_cond_metadata = update_metadata(os.path.join(opt.render_cond_root, 'renders_cond'), opt)
|
||||
|
||||
# merge mesh dumped
|
||||
mesh_dumped_metadata = update_metadata(os.path.join(opt.mesh_dump_root, 'mesh_dumps'), opt)
|
||||
|
||||
# merge pbr dumped
|
||||
pbr_dumped_metadata = update_metadata(os.path.join(opt.pbr_dump_root, 'pbr_dumps'), opt)
|
||||
|
||||
# merge asset stats
|
||||
asset_stats_metadata = update_metadata(os.path.join(opt.root, 'asset_stats'), opt)
|
||||
|
||||
# merge dual grid
|
||||
dual_grid_resolutions = []
|
||||
for dir in os.listdir(opt.dual_grid_root):
|
||||
if os.path.isdir(os.path.join(opt.dual_grid_root, dir)) and dir.startswith('dual_grid_'):
|
||||
dual_grid_resolutions.append(int(dir.split('_')[-1]))
|
||||
dual_grid_metadata = {}
|
||||
for res in dual_grid_resolutions:
|
||||
dual_grid_metadata[res] = update_metadata(os.path.join(opt.dual_grid_root, f'dual_grid_{res}'), opt)
|
||||
|
||||
# merge pbr voxelized
|
||||
pbr_voxel_resolutions = []
|
||||
for dir in os.listdir(opt.pbr_voxel_root):
|
||||
if os.path.isdir(os.path.join(opt.pbr_voxel_root, dir)) and dir.startswith('pbr_voxels_'):
|
||||
pbr_voxel_resolutions.append(int(dir.split('_')[-1]))
|
||||
pbr_voxel_metadata = {}
|
||||
for res in pbr_voxel_resolutions:
|
||||
pbr_voxel_metadata[res] = update_metadata(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{res}'), opt)
|
||||
|
||||
# merge ss latents
|
||||
ss_latent_models = []
|
||||
if os.path.exists(os.path.join(opt.ss_latent_root, 'ss_latents')):
|
||||
ss_latent_models = os.listdir(os.path.join(opt.ss_latent_root, 'ss_latents'))
|
||||
ss_latent_metadata = {}
|
||||
for model in ss_latent_models:
|
||||
ss_latent_metadata[model] = update_metadata(os.path.join(opt.ss_latent_root, f'ss_latents/{model}'), opt)
|
||||
|
||||
# merge shape latents
|
||||
shape_latent_models = []
|
||||
if os.path.exists(os.path.join(opt.shape_latent_root, 'shape_latents')):
|
||||
shape_latent_models = os.listdir(os.path.join(opt.shape_latent_root, 'shape_latents'))
|
||||
shape_latent_metadata = {}
|
||||
for model in shape_latent_models:
|
||||
shape_latent_metadata[model] = update_metadata(os.path.join(opt.shape_latent_root, f'shape_latents/{model}'), opt)
|
||||
|
||||
# merge pbr latents
|
||||
pbr_latent_models = []
|
||||
if os.path.exists(os.path.join(opt.pbr_latent_root, 'pbr_latents')):
|
||||
pbr_latent_models = os.listdir(os.path.join(opt.pbr_latent_root, 'pbr_latents'))
|
||||
pbr_latent_metadata = {}
|
||||
for model in pbr_latent_models:
|
||||
pbr_latent_metadata[model] = update_metadata(os.path.join(opt.pbr_latent_root, f'pbr_latents/{model}'), opt)
|
||||
|
||||
# statistics
|
||||
num_downloaded = downloaded_metadata['local_path'].count() if downloaded_metadata is not None else 0
|
||||
with open(os.path.join(opt.root, 'statistics.txt'), 'w') as f:
|
||||
f.write('Statistics:\n')
|
||||
f.write(f' - Number of assets: {len(metadata)}\n')
|
||||
f.write(f' - Number of assets downloaded: {num_downloaded}\n')
|
||||
if thumbnail_metadata is not None:
|
||||
f.write(f' - Number of assets with thumbnails: {thumbnail_metadata["thumbnailed"].sum()}\n')
|
||||
if aesthetic_score_metadata is not None:
|
||||
f.write(f' - Number of assets with aesthetic scores: {aesthetic_score_metadata["aesthetic_score"].count()}\n')
|
||||
if render_cond_metadata is not None:
|
||||
f.write(f' - Number of assets with render conditions: {render_cond_metadata["cond_rendered"].count()}\n')
|
||||
if mesh_dumped_metadata is not None:
|
||||
f.write(f' - Number of assets with mesh dumped: {mesh_dumped_metadata["mesh_dumped"].sum()}\n')
|
||||
if pbr_dumped_metadata is not None:
|
||||
f.write(f' - Number of assets with PBR dumped: {pbr_dumped_metadata["pbr_dumped"].sum()}\n')
|
||||
if asset_stats_metadata is not None:
|
||||
f.write(f' - Number of assets with asset stats: {len(asset_stats_metadata)}\n')
|
||||
if len(dual_grid_resolutions) != 0:
|
||||
f.write(f' - Number of assets with dual grid:\n')
|
||||
for res in dual_grid_resolutions:
|
||||
if dual_grid_metadata[res] is not None:
|
||||
f.write(f' - {res}: {dual_grid_metadata[res]["dual_grid_converted"].sum()}\n')
|
||||
if len(pbr_voxel_resolutions) != 0:
|
||||
f.write(f' - Number of assets with PBR voxelization:\n')
|
||||
for res in pbr_voxel_resolutions:
|
||||
if pbr_voxel_metadata[res] is not None:
|
||||
f.write(f' - {res}: {pbr_voxel_metadata[res]["pbr_voxelized"].sum()}\n')
|
||||
if len(ss_latent_models) != 0:
|
||||
f.write(f' - Number of assets with sparse structure latents:\n')
|
||||
for model in ss_latent_models:
|
||||
if ss_latent_metadata[model] is not None:
|
||||
f.write(f' - {model}: {ss_latent_metadata[model]["ss_latent_encoded"].sum()}\n')
|
||||
if len(shape_latent_models) != 0:
|
||||
f.write(f' - Number of assets with shape latents:\n')
|
||||
for model in shape_latent_models:
|
||||
if shape_latent_metadata[model] is not None:
|
||||
f.write(f' - {model}: {shape_latent_metadata[model]["shape_latent_encoded"].sum()}\n')
|
||||
if len(pbr_latent_models) != 0:
|
||||
f.write(f' - Number of assets with PBR latents:\n')
|
||||
for model in pbr_latent_models:
|
||||
if pbr_latent_metadata[model] is not None:
|
||||
f.write(f' - {model}: {pbr_latent_metadata[model]["pbr_latent_encoded"].sum()}\n')
|
||||
|
||||
with open(os.path.join(opt.root, 'statistics.txt'), 'r') as f:
|
||||
print(f.read())
|
||||
64
data_toolkit/download.py
Executable file
@@ -0,0 +1,64 @@
import os
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict


if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--download_root', type=str, default=None,
                        help='Directory to download the objects')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--check_only', action='store_true',
                        help='Only check if the objects are already downloaded')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))
    opt.download_root = opt.download_root or opt.root

    os.makedirs(opt.root, exist_ok=True)
    os.makedirs(opt.download_root, exist_ok=True)
    os.makedirs(os.path.join(opt.download_root, 'raw', 'new_records'), exist_ok=True)

    # get file list
    if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
    if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.download_root, 'raw', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.download_root, 'raw', 'metadata.csv')).set_index('sha256'))
    metadata = metadata.reset_index()
    if opt.instances is None:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'local_path' in metadata.columns:
            metadata = metadata[metadata['local_path'].isna()]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    downloaded = dataset_utils.download(metadata, **opt)
    downloaded.to_csv(os.path.join(opt.download_root, 'raw', 'new_records', f'part_{opt.rank}.csv'), index=False)
170
data_toolkit/dual_grid.py
Executable file
@@ -0,0 +1,170 @@
import os
import sys
import importlib
import argparse
import pandas as pd
import numpy as np
import torch
import pickle
import o_voxel
from easydict import EasyDict as edict
from functools import partial


def _dual_grid_mesh(file, metadatum, mesh_dump_root, root):
    sha256 = metadatum['sha256']
    try:
        pack = {'sha256': sha256}
        data = None
        for res in opt.resolution:
            need_process = False

            # check if already processed
            if os.path.exists(os.path.join(root, f'dual_grid_{res}', f'{sha256}.vxz')):
                try:
                    info = o_voxel.io.read_vxz_info(os.path.join(root, f'dual_grid_{res}', f'{sha256}.vxz'))
                    pack[f'dual_grid_converted_{res}'] = True
                    pack[f'dual_grid_size_{res}'] = info['num_voxel']
                except Exception as e:
                    print(f'Error reading {sha256}.vxz: {e}')
                    need_process = True
            else:
                need_process = True

            # process mesh
            if need_process:
                if data is None:
                    with open(os.path.join(mesh_dump_root, 'mesh_dumps', f'{sha256}.pickle'), 'rb') as f:
                        dump = pickle.load(f)
                    start = 0
                    vertices = []
                    faces = []
                    for obj in dump['objects']:
                        if obj['vertices'].size == 0 or obj['faces'].size == 0:
                            continue
                        vertices.append(obj['vertices'])
                        faces.append(obj['faces'] + start)
                        start += len(obj['vertices'])
                    vertices = torch.from_numpy(np.concatenate(vertices, axis=0)).float()
                    faces = torch.from_numpy(np.concatenate(faces, axis=0)).long()
                    vertices_min = vertices.min(dim=0)[0]
                    vertices_max = vertices.max(dim=0)[0]
                    center = (vertices_min + vertices_max) / 2
                    scale = 0.99999 / (vertices_max - vertices_min).max()
                    vertices = (vertices - center) * scale
                    assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range'
                    data = {'vertices': vertices, 'faces': faces}

                voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
                    **data,
                    grid_size=res,
                    aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
                    face_weight=1.0,
                    boundary_weight=0.2,
                    regularization_weight=1e-2,
                    timing=False,
                )
                dual_vertices = dual_vertices * res - voxel_indices
                assert torch.all(dual_vertices >= -1e-3) and torch.all(dual_vertices <= 1 + 1e-3), 'dual_vertices out of range'
                dual_vertices = torch.clamp(dual_vertices, 0, 1)
                dual_vertices = (dual_vertices * 255).type(torch.uint8)
                intersected = (intersected[:, 0:1] + 2 * intersected[:, 1:2] + 4 * intersected[:, 2:3]).type(torch.uint8)

                o_voxel.io.write_vxz(
                    os.path.join(root, f'dual_grid_{res}', f'{sha256}.vxz'),
                    voxel_indices,
                    {'vertices': dual_vertices, 'intersected': intersected},
                )

                pack[f'dual_grid_converted_{res}'] = True
                pack[f'dual_grid_size_{res}'] = len(dual_vertices)

        return pack
    except Exception as e:
        print(f'Error processing {sha256}: {e}')
        return {'sha256': sha256, 'error': str(e)}


if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--mesh_dump_root', type=str, default=None,
                        help='Directory to load mesh dumps')
    parser.add_argument('--dual_grid_root', type=str, default=None,
                        help='Directory to save dual grids')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--resolution', type=str, default='256')
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=0)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))
    opt.resolution = [int(x) for x in opt.resolution.split(',')]
    opt.mesh_dump_root = opt.mesh_dump_root or opt.root
    opt.dual_grid_root = opt.dual_grid_root or opt.root

    for res in opt.resolution:
        os.makedirs(os.path.join(opt.dual_grid_root, f'dual_grid_{res}', 'new_records'), exist_ok=True)

    # get file list
    if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
    if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'metadata.csv')).set_index('sha256'))
    for res in opt.resolution:
        if os.path.exists(os.path.join(opt.dual_grid_root, f'dual_grid_{res}', 'metadata.csv')):
            dual_grid_metadata = pd.read_csv(os.path.join(opt.dual_grid_root, f'dual_grid_{res}', 'metadata.csv')).set_index('sha256')
            dual_grid_metadata = dual_grid_metadata.rename(columns={'dual_grid_converted': f'dual_grid_converted_{res}', 'dual_grid_size': f'dual_grid_size_{res}'})
            metadata = metadata.combine_first(dual_grid_metadata)
    metadata = metadata.reset_index()
    if opt.instances is None:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        metadata = metadata[metadata['mesh_dumped'] == True]
        mask = np.zeros(len(metadata), dtype=bool)
        for res in opt.resolution:
            if f'dual_grid_converted_{res}' in metadata.columns:
                mask |= metadata[f'dual_grid_converted_{res}'] != True
            else:
                mask[:] = True
                break
        metadata = metadata[mask]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    func = partial(_dual_grid_mesh, root=opt.dual_grid_root, mesh_dump_root=opt.mesh_dump_root)
    dual_grids = dataset_utils.foreach_instance(metadata, None, func, max_workers=opt.max_workers, no_file=True, desc='Dual griding')
    if 'error' in dual_grids.columns:
        errors = dual_grids[dual_grids['error'].notna()]
        with open('errors.txt', 'w') as f:
            f.write('\n'.join(errors['sha256'].tolist()))
    for res in opt.resolution:
        if f'dual_grid_converted_{res}' in dual_grids.columns:
            dual_grid_metadata = dual_grids[dual_grids[f'dual_grid_converted_{res}'] == True]
            if len(dual_grid_metadata) > 0:
                dual_grid_metadata = dual_grid_metadata[['sha256', f'dual_grid_converted_{res}', f'dual_grid_size_{res}']]
                dual_grid_metadata = dual_grid_metadata.rename(columns={f'dual_grid_converted_{res}': 'dual_grid_converted', f'dual_grid_size_{res}': 'dual_grid_size'})
                dual_grid_metadata.to_csv(os.path.join(opt.dual_grid_root, f'dual_grid_{res}', 'new_records', f'part_{opt.rank}.csv'), index=False)
127
data_toolkit/dump_mesh.py
Executable file
@@ -0,0 +1,127 @@
import os
import shutil
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
from subprocess import DEVNULL, call
import numpy as np
import tempfile


BLENDER_LINK = 'https://ftp.halifax.rwth-aachen.de/blender/release/Blender4.5/blender-4.5.1-linux-x64.tar.xz'
BLENDER_INSTALLATION_PATH = '/tmp'
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-4.5.1-linux-x64/blender'


def _install_blender():
    if not os.path.exists(BLENDER_PATH):
        os.system('sudo apt-get update')
        os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6 libxfixes3 libgl1')
        os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
        os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-4.5.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')


def _dump_mesh(file_path, metadatum, root):
    sha256 = metadatum['sha256']
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_path = os.path.join(tmp_dir, f'{sha256}.pickle')
        output_path = os.path.join(root, 'mesh_dumps', f'{sha256}.pickle')
        args = [
            BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'dump_mesh.py'),
            '--',
            '--object', os.path.expanduser(file_path),
            '--output_path', os.path.expanduser(temp_path)
        ]
        if file_path.endswith('.blend'):
            args.insert(1, file_path)

        call(args, stdout=DEVNULL, stderr=DEVNULL)

        if os.path.exists(temp_path):
            shutil.move(temp_path, output_path)
            return {'sha256': sha256, 'mesh_dumped': True}
        else:
            if os.path.exists(temp_path + '_error.txt'):
                with open(temp_path + '_error.txt', 'r') as f:
                    error_msg = f.read()
                raise ValueError(f'Failed to dump mesh. File {file_path}. Error message: {error_msg}')
            else:
                raise ValueError(f'Failed to dump mesh. File {file_path}.')


if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--download_root', type=str, default=None,
                        help='Directory to save the downloaded files')
    parser.add_argument('--mesh_dump_root', type=str, default=None,
                        help='Directory to save the mesh dumps')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=0)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))
    opt.download_root = opt.download_root or opt.root
    opt.mesh_dump_root = opt.mesh_dump_root or opt.root

    os.makedirs(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'new_records'), exist_ok=True)

    # install blender
    print('Checking blender...', flush=True)
    _install_blender()

    # get file list
    if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
    if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.download_root, 'raw', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.download_root, 'raw', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'metadata.csv')).set_index('sha256'))
    metadata = metadata.reset_index()
    if opt.instances is None:
        metadata = metadata[metadata['local_path'].notna()]
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'mesh_dumped' in metadata.columns:
            metadata = metadata[metadata['mesh_dumped'] != True]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    sha256_list = os.listdir(os.path.join(opt.mesh_dump_root, 'mesh_dumps'))
    sha256_list = [os.path.splitext(f)[0] for f in sha256_list if f.endswith('.pickle')]
    for sha256 in sha256_list:
        records.append({'sha256': sha256, 'mesh_dumped': True})
    print(f'Found {len(sha256_list)} dumped meshes')
    metadata = metadata[~metadata['sha256'].isin(sha256_list)]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    func = partial(_dump_mesh, root=opt.mesh_dump_root)
    mesh_dumped = dataset_utils.foreach_instance(metadata, opt.download_root, func, max_workers=opt.max_workers, desc='Dumping mesh')
    mesh_dumped = pd.concat([mesh_dumped, pd.DataFrame.from_records(records)])
    mesh_dumped.to_csv(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'new_records', f'part_{opt.rank}.csv'), index=False)
128
data_toolkit/dump_pbr.py
Executable file
@@ -0,0 +1,128 @@
|
||||
import os
|
||||
import shutil
|
||||
import copy
|
||||
import sys
|
||||
import importlib
|
||||
import argparse
|
||||
import pandas as pd
|
||||
from easydict import EasyDict as edict
|
||||
from functools import partial
|
||||
from subprocess import DEVNULL, call
|
||||
import numpy as np
|
||||
import tempfile
|
||||
|
||||
|
||||
BLENDER_LINK = 'https://ftp.halifax.rwth-aachen.de/blender/release/Blender4.5/blender-4.5.1-linux-x64.tar.xz'
|
||||
BLENDER_INSTALLATION_PATH = '/tmp'
|
||||
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-4.5.1-linux-x64/blender'
|
||||
|
||||
def _install_blender():
|
||||
if not os.path.exists(BLENDER_PATH):
|
||||
os.system('sudo apt-get update')
|
||||
os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6 libxfixes3 libgl1')
|
||||
os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
|
||||
os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-4.5.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')
|
||||
os.system(f'{BLENDER_PATH} -b --python {os.path.join(os.path.dirname(__file__), "blender_script", "install_pillow.py")}')
|
||||
|
||||
|
||||
def _dump_pbr(file_path, metadatum, root):
|
||||
sha256 = metadatum['sha256']
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
temp_path = os.path.join(tmp_dir, f'{sha256}.pickle')
|
||||
output_path = os.path.join(root, 'pbr_dumps', f'{sha256}.pickle')
|
||||
args = [
|
||||
BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'dump_pbr.py'),
|
||||
'--',
|
||||
'--object', os.path.expanduser(file_path),
|
||||
'--output_path', os.path.expanduser(temp_path)
|
||||
]
|
||||
if file_path.endswith('.blend'):
|
||||
args.insert(1, file_path)
|
||||
|
||||
call(args, stdout=DEVNULL, stderr=DEVNULL)
|
||||
|
||||
if os.path.exists(temp_path):
|
||||
shutil.move(temp_path, output_path)
|
||||
return {'sha256': sha256, 'pbr_dumped': True}
|
||||
else:
|
||||
if os.path.exists(temp_path + '_error.txt'):
|
||||
with open(temp_path + '_error.txt', 'r') as f:
|
||||
error_msg = f.read()
|
||||
raise ValueError(f'Failed to dump PBR. File {file_path}. Error message: {error_msg}')
|
||||
else:
|
||||
raise ValueError(f'Failed to dump PBR. File {file_path}.')
|
||||
|
||||
if __name__ == '__main__':
|
||||
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--root', type=str, required=True,
|
||||
help='Directory to save the metadata')
|
||||
parser.add_argument('--download_root', type=str, default=None,
|
||||
help='Directory to save the downloaded files')
|
||||
parser.add_argument('--pbr_dump_root', type=str, default=None,
|
||||
help='Directory to save the mesh dumps')
|
||||
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
|
||||
help='Filter objects with aesthetic score lower than this value')
|
||||
parser.add_argument('--instances', type=str, default=None,
|
||||
help='Instances to process')
|
||||
dataset_utils.add_args(parser)
|
||||
parser.add_argument('--rank', type=int, default=0)
|
||||
parser.add_argument('--world_size', type=int, default=1)
|
||||
parser.add_argument('--max_workers', type=int, default=0)
|
||||
opt = parser.parse_args(sys.argv[2:])
|
||||
opt = edict(vars(opt))
|
||||
opt.download_root = opt.download_root or opt.root
|
||||
opt.pbr_dump_root = opt.pbr_dump_root or opt.root
|
||||
|
||||
os.makedirs(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'new_records'), exist_ok=True)
|
||||
|
||||
# install blender
|
||||
print('Checking blender...', flush=True)
|
||||
_install_blender()
|
||||
|
||||
# get file list
|
||||
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
|
||||
raise ValueError('metadata.csv not found')
|
||||
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
|
||||
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
|
||||
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.download_root, 'raw', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.download_root, 'raw', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')).set_index('sha256'))
    metadata = metadata.reset_index()
    if opt.instances is None:
        metadata = metadata[metadata['local_path'].notna()]
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'pbr_dumped' in metadata.columns:
            metadata = metadata[metadata['pbr_dumped'] != True]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    sha256_list = os.listdir(os.path.join(opt.pbr_dump_root, 'pbr_dumps'))
    sha256_list = [os.path.splitext(f)[0] for f in sha256_list if f.endswith('.pickle')]
    for sha256 in sha256_list:
        records.append({'sha256': sha256, 'pbr_dumped': True})
    print(f'Found {len(sha256_list)} dumped PBRs')
    metadata = metadata[~metadata['sha256'].isin(sha256_list)]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    func = partial(_dump_pbr, root=opt.pbr_dump_root)
    pbr_dumped = dataset_utils.foreach_instance(metadata, opt.download_root, func, max_workers=opt.max_workers, desc='Dumping PBR')
    pbr_dumped = pd.concat([pbr_dumped, pd.DataFrame.from_records(records)])
    pbr_dumped.to_csv(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'new_records', f'part_{opt.rank}.csv'), index=False)
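The `rank`/`world_size` split used throughout these scripts assigns each worker a contiguous, near-equal slice of the metadata. A minimal standalone sketch of the same arithmetic (the `shard` helper is illustrative, not part of the toolkit):

```python
# Illustrative helper (not in the toolkit) mirroring the start/end
# arithmetic used by the data scripts for multi-worker sharding.
def shard(items, rank, world_size):
    start = len(items) * rank // world_size
    end = len(items) * (rank + 1) // world_size
    return items[start:end]

# Every item lands on exactly one rank, even when the length is not
# divisible by world_size.
parts = [shard(list(range(10)), r, 3) for r in range(3)]
print(parts)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
```

Because the slices are computed from the same global length on every worker, no coordination is needed beyond passing a distinct `--rank`.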
181  data_toolkit/encode_pbr_latent.py  Normal file
@@ -0,0 +1,181 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import json
import argparse
import torch
import numpy as np
import pandas as pd
import o_voxel
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue

import trellis2.models as models
import trellis2.modules.sparse as sp

torch.set_grad_enabled(False)


def is_valid_sparse_tensor(tensor):
    return torch.isfinite(tensor.feats).all() and torch.isfinite(tensor.coords).all()


def clear_cuda_error():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--pbr_voxel_root', type=str, default=None,
                        help='Directory to save the PBR voxel files')
    parser.add_argument('--pbr_latent_root', type=str, default=None,
                        help='Directory to save the PBR latent files')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--resolution', type=int, default=1024,
                        help='Sparse voxel resolution')
    parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS.2-4B/ckpts/tex_enc_next_dc_f16c32_fp16',
                        help='Pretrained encoder model')
    parser.add_argument('--model_root', type=str,
                        help='Root directory of models')
    parser.add_argument('--enc_model', type=str,
                        help='Encoder model. If specified, use this model instead of the pretrained model')
    parser.add_argument('--ckpt', type=str,
                        help='Checkpoint to load')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args()
    opt = edict(vars(opt))
    opt.pbr_voxel_root = opt.pbr_voxel_root or opt.root
    opt.pbr_latent_root = opt.pbr_latent_root or opt.root

    if opt.enc_model is None:
        latent_name = f'{opt.enc_pretrained.split("/")[-1]}_{opt.resolution}'
        encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
    else:
        latent_name = f'{opt.enc_model.split("/")[-1]}_{opt.ckpt}_{opt.resolution}'
        cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
        encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
        ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
        encoder.load_state_dict(torch.load(ckpt_path), strict=False)
        encoder.eval()
        print(f'Loaded model from {ckpt_path}')

    os.makedirs(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, 'new_records'), exist_ok=True)

    # get file list
    if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
    if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{opt.resolution}', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{opt.resolution}', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, 'metadata.csv')).set_index('sha256'))
    metadata = metadata.reset_index()
    if opt.instances is None:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        metadata = metadata[metadata['pbr_voxelized'] == True]
        if 'pbr_latent_encoded' in metadata.columns:
            metadata = metadata[metadata['pbr_latent_encoded'] != True]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
         tqdm(total=len(metadata), desc="Filtering existing objects") as pbar:
        def check_sha256(sha256):
            if os.path.exists(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, f'{sha256}.npz')):
                coords = np.load(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, f'{sha256}.npz'))['coords']
                records.append({'sha256': sha256, 'pbr_latent_encoded': True, 'pbr_latent_tokens': coords.shape[0]})
            pbar.update()
        executor.map(check_sha256, metadata['sha256'].values)
        executor.shutdown(wait=True)
    existing_sha256 = set(r['sha256'] for r in records)
    print(f'Found {len(existing_sha256)} processed objects')
    metadata = metadata[~metadata['sha256'].isin(existing_sha256)]

    print(f'Processing {len(metadata)} objects...')

    sha256s = list(metadata['sha256'].values)
    load_queue = Queue(maxsize=32)
    with ThreadPoolExecutor(max_workers=32) as loader_executor, \
         ThreadPoolExecutor(max_workers=32) as saver_executor:

        def loader(sha256):
            try:
                attrs = ['base_color', 'metallic', 'roughness', 'alpha']
                coords, attr = o_voxel.io.read_vxz(
                    os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{opt.resolution}', f'{sha256}.vxz'),
                    num_threads=4
                )
                feats = torch.concat([attr[k] for k in attrs], dim=-1) / 255.0 * 2 - 1
                x = sp.SparseTensor(
                    feats.float(),
                    torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1),
                )
                load_queue.put((sha256, x))
            except Exception as e:
                print(f"[Loader Error] {sha256}: {e}")
                load_queue.put((sha256, None))

        loader_executor.map(loader, sha256s)

        def saver(sha256, pack):
            save_path = os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, f'{sha256}.npz')
            np.savez_compressed(save_path, **pack)
            records.append({'sha256': sha256, 'pbr_latent_encoded': True, 'pbr_latent_tokens': pack['coords'].shape[0]})

        for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
            try:
                sha256, voxels = load_queue.get()
                if voxels is None:
                    print(f"[Skip] {sha256}: Failed to load input")
                    continue

                num_voxels = voxels.feats.shape[0]

                # NaN/Inf check
                if not is_valid_sparse_tensor(voxels):
                    print(f"[Skip] {sha256}: NaN/Inf in input")
                    continue

                z = encoder(voxels.cuda())
                torch.cuda.synchronize()

                if not torch.isfinite(z.feats).all():
                    print(f"[Skip] {sha256}: Non-finite latent in z.feats")
                    clear_cuda_error()
                    continue

                pack = {
                    'feats': z.feats.cpu().numpy().astype(np.float32),
                    'coords': z.coords[:, 1:].cpu().numpy().astype(np.uint8),
                }
                saver_executor.submit(saver, sha256, pack)

            except Exception as e:
                print(f"[Error] {sha256} ({num_voxels} voxels): {e}")
                clear_cuda_error()
                continue

        saver_executor.shutdown(wait=True)

    records = pd.DataFrame.from_records(records)
    records.to_csv(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, 'new_records', f'part_{opt.rank}.csv'), index=False)
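The loader above maps the packed 8-bit PBR attributes from `[0, 255]` to the encoder's `[-1, 1]` input range. The same mapping in isolation (plain Python, illustrative only):

```python
# Illustrative version of the `feats / 255.0 * 2 - 1` normalization used in
# the loader: packed uint8 attribute values are rescaled to [-1, 1].
def to_unit_range(v: int) -> float:
    return v / 255.0 * 2 - 1

print([to_unit_range(v) for v in (0, 255)])  # [-1.0, 1.0]
```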
184  data_toolkit/encode_shape_latent.py  Normal file
@@ -0,0 +1,184 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import json
import argparse
import torch
import numpy as np
import pandas as pd
import o_voxel
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue

import trellis2.models as models
import trellis2.modules.sparse as sp

torch.set_grad_enabled(False)


def is_valid_sparse_tensor(tensor):
    return torch.isfinite(tensor.feats).all() and torch.isfinite(tensor.coords).all()


def clear_cuda_error():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--dual_grid_root', type=str, default=None,
                        help='Directory to save the dual grids')
    parser.add_argument('--shape_latent_root', type=str, default=None,
                        help='Directory to save the shape latent files')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--resolution', type=int, default=1024,
                        help='Sparse voxel resolution')
    parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS.2-4B/ckpts/shape_enc_next_dc_f16c32_fp16',
                        help='Pretrained encoder model')
    parser.add_argument('--model_root', type=str,
                        help='Root directory of models')
    parser.add_argument('--enc_model', type=str,
                        help='Encoder model. If specified, use this model instead of the pretrained model')
    parser.add_argument('--ckpt', type=str,
                        help='Checkpoint to load')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args()
    opt = edict(vars(opt))
    opt.dual_grid_root = opt.dual_grid_root or opt.root
    opt.shape_latent_root = opt.shape_latent_root or opt.root

    if opt.enc_model is None:
        latent_name = f'{opt.enc_pretrained.split("/")[-1]}_{opt.resolution}'
        encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
    else:
        latent_name = f'{opt.enc_model.split("/")[-1]}_{opt.ckpt}_{opt.resolution}'
        cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
        encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
        ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
        encoder.load_state_dict(torch.load(ckpt_path), strict=False)
        encoder.eval()
        print(f'Loaded model from {ckpt_path}')

    os.makedirs(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, 'new_records'), exist_ok=True)

    # get file list
    if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
    if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.dual_grid_root, f'dual_grid_{opt.resolution}', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.dual_grid_root, f'dual_grid_{opt.resolution}', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, 'metadata.csv')).set_index('sha256'))
    metadata = metadata.reset_index()
    if opt.instances is None:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        metadata = metadata[metadata['dual_grid_converted'] == True]
        if 'shape_latent_encoded' in metadata.columns:
            metadata = metadata[metadata['shape_latent_encoded'] != True]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
         tqdm(total=len(metadata), desc="Filtering existing objects") as pbar:
        def check_sha256(sha256):
            if os.path.exists(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, f'{sha256}.npz')):
                coords = np.load(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, f'{sha256}.npz'))['coords']
                records.append({'sha256': sha256, 'shape_latent_encoded': True, 'shape_latent_tokens': coords.shape[0]})
            pbar.update()
        executor.map(check_sha256, metadata['sha256'].values)
        executor.shutdown(wait=True)
    existing_sha256 = set(r['sha256'] for r in records)
    print(f'Found {len(existing_sha256)} processed objects')
    metadata = metadata[~metadata['sha256'].isin(existing_sha256)]

    print(f'Processing {len(metadata)} objects...')

    sha256s = list(metadata['sha256'].values)
    load_queue = Queue(maxsize=32)
    with ThreadPoolExecutor(max_workers=32) as loader_executor, \
         ThreadPoolExecutor(max_workers=32) as saver_executor:

        def loader(sha256):
            try:
                coords, attr = o_voxel.io.read_vxz(
                    os.path.join(opt.dual_grid_root, f'dual_grid_{opt.resolution}', f'{sha256}.vxz'),
                    num_threads=4
                )
                vertices = sp.SparseTensor(
                    (attr['vertices'] / 255.0).float(),
                    torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1),
                )
                intersected = vertices.replace(torch.cat([
                    attr['intersected'] % 2,
                    attr['intersected'] // 2 % 2,
                    attr['intersected'] // 4 % 2,
                ], dim=-1).bool())
                load_queue.put((sha256, vertices, intersected))
            except Exception as e:
                print(f"[Loader Error] {sha256}: {e}")
                load_queue.put((sha256, None, None))

        loader_executor.map(loader, sha256s)

        def saver(sha256, pack):
            save_path = os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, f'{sha256}.npz')
            np.savez_compressed(save_path, **pack)
            records.append({'sha256': sha256, 'shape_latent_encoded': True, 'shape_latent_tokens': pack['coords'].shape[0]})

        for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
            try:
                sha256, vertices, intersected = load_queue.get()
                if vertices is None or intersected is None:
                    print(f"[Skip] {sha256}: Failed to load input")
                    continue

                num_voxels = vertices.feats.shape[0]

                # NaN/Inf check
                if not (is_valid_sparse_tensor(vertices) and is_valid_sparse_tensor(intersected)):
                    print(f"[Skip] {sha256}: NaN/Inf in input")
                    continue

                z = encoder(vertices.cuda(), intersected.cuda())
                torch.cuda.synchronize()

                if not torch.isfinite(z.feats).all():
                    print(f"[Skip] {sha256}: Non-finite latent in z.feats")
                    clear_cuda_error()
                    continue

                pack = {
                    'feats': z.feats.cpu().numpy().astype(np.float32),
                    'coords': z.coords[:, 1:].cpu().numpy().astype(np.uint8),
                }
                saver_executor.submit(saver, sha256, pack)

            except Exception as e:
                print(f"[Error] {sha256} ({num_voxels} voxels): {e}")
                clear_cuda_error()
                continue

        saver_executor.shutdown(wait=True)

    records = pd.DataFrame.from_records(records)
    records.to_csv(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, 'new_records', f'part_{opt.rank}.csv'), index=False)
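The `intersected` attribute above packs three per-axis edge-intersection flags into the low bits of a single byte; the `% 2`, `// 2 % 2`, `// 4 % 2` expressions are plain bit extraction. An illustrative decode:

```python
# Illustrative decode of the packed `intersected` byte: the three low bits
# carry one intersection flag per axis (bit 0, bit 1, bit 2).
def unpack_flags(b: int):
    return (bool(b % 2), bool(b // 2 % 2), bool(b // 4 % 2))

print(unpack_flags(5))  # (True, False, True)
print(unpack_flags(7))  # (True, True, True)
```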
163  data_toolkit/encode_ss_latent.py  Normal file
@@ -0,0 +1,163 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import json
import argparse
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue

import trellis2.models as models

torch.set_grad_enabled(False)


def is_valid_sparse_tensor(tensor):
    return torch.isfinite(tensor.feats).all() and torch.isfinite(tensor.coords).all()


def clear_cuda_error():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--shape_latent_root', type=str, default=None,
                        help='Directory to save the shape latent files')
    parser.add_argument('--ss_latent_root', type=str, default=None,
                        help='Directory to save the sparse structure latent files')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--resolution', type=int, default=64,
                        help='Sparse voxel resolution')
    parser.add_argument('--shape_latent_name', type=str, default=None,
                        help='Name of the shape latent files')
    parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS-image-large/ckpts/ss_enc_conv3d_16l8_fp16',
                        help='Pretrained encoder model')
    parser.add_argument('--model_root', type=str,
                        help='Root directory of models')
    parser.add_argument('--enc_model', type=str,
                        help='Encoder model. If specified, use this model instead of the pretrained model')
    parser.add_argument('--ckpt', type=str,
                        help='Checkpoint to load')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args()
    opt = edict(vars(opt))
    opt.shape_latent_root = opt.shape_latent_root or opt.root
    opt.ss_latent_root = opt.ss_latent_root or opt.root

    if opt.enc_model is None:
        latent_name = f'{opt.enc_pretrained.split("/")[-1]}_{opt.resolution}'
        encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
    else:
        latent_name = f'{opt.enc_model.split("/")[-1]}_{opt.ckpt}_{opt.resolution}'
        cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
        encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
        ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
        encoder.load_state_dict(torch.load(ckpt_path), strict=False)
        encoder.eval()
        print(f'Loaded model from {ckpt_path}')

    os.makedirs(os.path.join(opt.ss_latent_root, 'ss_latents', latent_name, 'new_records'), exist_ok=True)

    # get file list
    if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
    if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.shape_latent_root, 'shape_latents', opt.shape_latent_name, 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.shape_latent_root, 'shape_latents', opt.shape_latent_name, 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.ss_latent_root, 'ss_latents', latent_name, 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.ss_latent_root, 'ss_latents', latent_name, 'metadata.csv')).set_index('sha256'))
    metadata = metadata.reset_index()
    if opt.instances is None:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        metadata = metadata[metadata['shape_latent_encoded'] == True]
        if 'ss_latent_encoded' in metadata.columns:
            metadata = metadata[metadata['ss_latent_encoded'] != True]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    sha256_list = os.listdir(os.path.join(opt.ss_latent_root, 'ss_latents', latent_name))
    sha256_list = [os.path.splitext(f)[0] for f in sha256_list if f.endswith('.npz')]
    for sha256 in sha256_list:
        records.append({'sha256': sha256, 'ss_latent_encoded': True})
    print(f'Found {len(sha256_list)} processed objects')
    metadata = metadata[~metadata['sha256'].isin(sha256_list)]

    print(f'Processing {len(metadata)} objects...')

    sha256s = list(metadata['sha256'].values)
    load_queue = Queue(maxsize=32)
    with ThreadPoolExecutor(max_workers=32) as loader_executor, \
         ThreadPoolExecutor(max_workers=32) as saver_executor:

        def loader(sha256):
            try:
                coords = np.load(os.path.join(opt.shape_latent_root, 'shape_latents', opt.shape_latent_name, f'{sha256}.npz'))['coords']
                assert np.all(coords < opt.resolution), f"{sha256}: Invalid coords"
                coords = torch.from_numpy(coords).long()
                ss = torch.zeros(1, opt.resolution, opt.resolution, opt.resolution, dtype=torch.long)
                ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
                load_queue.put((sha256, ss))
            except Exception as e:
                print(f"[Loader Error] {sha256}: {e}")
                load_queue.put((sha256, None))

        loader_executor.map(loader, sha256s)

        def saver(sha256, pack):
            save_path = os.path.join(opt.ss_latent_root, 'ss_latents', latent_name, f'{sha256}.npz')
            np.savez_compressed(save_path, **pack)
            records.append({'sha256': sha256, 'ss_latent_encoded': True})

        for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
            try:
                sha256, ss = load_queue.get()
                if ss is None:
                    print(f"[Skip] {sha256}: Failed to load input")
                    continue

                ss = ss.cuda()[None].float()
                z = encoder(ss, sample_posterior=False)
                torch.cuda.synchronize()

                if not torch.isfinite(z).all():
                    print(f"[Skip] {sha256}: Non-finite latent")
                    clear_cuda_error()
                    continue

                pack = {
                    'z': z[0].cpu().numpy(),
                }
                saver_executor.submit(saver, sha256, pack)

            except Exception as e:
                print(f"[Error] {sha256}: {e}")
                clear_cuda_error()
                continue

        saver_executor.shutdown(wait=True)

    records = pd.DataFrame.from_records(records)
    records.to_csv(os.path.join(opt.ss_latent_root, 'ss_latents', latent_name, 'new_records', f'part_{opt.rank}.csv'), index=False)
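The loader above scatters the sparse shape-latent coordinates into a dense binary occupancy grid before encoding. The same scatter in isolation (NumPy instead of torch, illustrative only):

```python
import numpy as np

# Illustrative scatter: sparse integer coordinates become a dense binary
# occupancy grid, as in the loader above (torch version).
res = 4
coords = np.array([[0, 0, 0], [3, 2, 1]])
ss = np.zeros((res, res, res), dtype=np.int64)
ss[coords[:, 0], coords[:, 1], coords[:, 2]] = 1
print(int(ss.sum()))  # 2
```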
146  data_toolkit/render_cond.py  Normal file
@@ -0,0 +1,146 @@
import os
import json
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
from subprocess import DEVNULL, call
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import numpy as np
from utils import sphere_hammersley_sequence


BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
BLENDER_INSTALLATION_PATH = '/tmp'
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'


def _install_blender():
    if not os.path.exists(BLENDER_PATH):
        os.system('sudo apt-get update')
        os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6 libxfixes3 libgl1')
        os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
        os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')


def _render_cond(file_path, metadatum, root, num_cond_views):
    sha256 = metadatum['sha256']
    # Build conditional view cameras
    yaws = []
    pitchs = []
    offset = (np.random.rand(), np.random.rand())
    for i in range(num_cond_views):
        y, p = sphere_hammersley_sequence(i, num_cond_views, offset)
        yaws.append(y)
        pitchs.append(p)
    fov_min, fov_max = 10, 70
    radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi)
    radius_max = np.sqrt(3) / 2 / np.sin(fov_min / 360 * np.pi)
    k_min = 1 / radius_max**2
    k_max = 1 / radius_min**2
    ks = np.random.uniform(k_min, k_max, (1000000,))
    radius = [1 / np.sqrt(k) for k in ks]
    fov = [2 * np.arcsin(np.sqrt(3) / 2 / r) for r in radius]
    cond_views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)]

    args = [
        BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render_cond.py'),
        '--',
        '--object', os.path.expanduser(file_path),
        '--cond_views', json.dumps(cond_views),
        '--cond_resolution', '1024',
        '--cond_output_folder', os.path.join(root, 'renders_cond', sha256),
        '--engine', 'CYCLES',
    ]
    if file_path.endswith('.blend'):
        args.insert(1, file_path)

    call(args, stdout=DEVNULL, stderr=DEVNULL)

    if os.path.exists(os.path.join(root, 'renders_cond', sha256, 'transforms.json')):
        return {'sha256': sha256, 'cond_rendered': True}


if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--download_root', type=str, default=None,
                        help='Directory to save the downloaded files')
    parser.add_argument('--render_cond_root', type=str, default=None,
                        help='Directory to save the conditional renders')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--num_cond_views', type=int, default=16,
                        help='Number of conditional views to render')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=8)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))
    opt.download_root = opt.download_root or opt.root
    opt.render_cond_root = opt.render_cond_root or opt.root

    os.makedirs(os.path.join(opt.render_cond_root, 'renders_cond', 'new_records'), exist_ok=True)

    # install blender
    print('Checking blender...', flush=True)
    _install_blender()

    # get file list
    if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
    if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.download_root, 'raw', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.download_root, 'raw', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.render_cond_root, 'renders_cond', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.render_cond_root, 'renders_cond', 'metadata.csv')).set_index('sha256'))
    metadata = metadata.reset_index()
    if opt.instances is None:
        metadata = metadata[metadata['local_path'].notna()]
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'cond_rendered' in metadata.columns:
            metadata = metadata[metadata['cond_rendered'] != True]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
         tqdm(total=len(metadata), desc="Filtering existing objects") as pbar:
        def check_sha256(sha256):
            if os.path.exists(os.path.join(opt.render_cond_root, 'renders_cond', sha256, 'transforms.json')):
                records.append({'sha256': sha256, 'cond_rendered': True})
            pbar.update()
        executor.map(check_sha256, metadata['sha256'].values)
        executor.shutdown(wait=True)
    existing_sha256 = set(r['sha256'] for r in records)
    metadata = metadata[~metadata['sha256'].isin(existing_sha256)]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    func = partial(_render_cond, root=opt.render_cond_root, num_cond_views=opt.num_cond_views)
    cond_rendered = dataset_utils.foreach_instance(metadata, opt.render_cond_root, func, max_workers=opt.max_workers, desc='Rendering objects')
    cond_rendered = pd.concat([cond_rendered, pd.DataFrame.from_records(records)])
    cond_rendered.to_csv(os.path.join(opt.render_cond_root, 'renders_cond', 'new_records', f'part_{opt.rank}.csv'), index=False)
|

data_toolkit/setup.sh (new executable file, 1 line)
@@ -0,0 +1 @@
pip install pillow imageio imageio-ffmpeg tqdm easydict opencv-python-headless pandas open3d objaverse huggingface_hub[cli] open_clip_torch

data_toolkit/utils.py (new executable file, 440 lines)
@@ -0,0 +1,440 @@
from typing import *
import hashlib
import numpy as np
import cv2


def get_file_hash(file: str) -> str:
    sha256 = hashlib.sha256()
    # Read the file from the path
    with open(file, "rb") as f:
        # Update the hash with the file content
        for byte_block in iter(lambda: f.read(4096), b""):
            sha256.update(byte_block)
    return sha256.hexdigest()


# ===============LOW DISCREPANCY SEQUENCES================

PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]


def radical_inverse(base, n):
    val = 0
    inv_base = 1.0 / base
    inv_base_n = inv_base
    while n > 0:
        digit = n % base
        val += digit * inv_base_n
        n //= base
        inv_base_n *= inv_base
    return val


def halton_sequence(dim, n):
    return [radical_inverse(PRIMES[dim], n) for dim in range(dim)]


def hammersley_sequence(dim, n, num_samples):
    return [n / num_samples] + halton_sequence(dim - 1, n)


def sphere_hammersley_sequence(n, num_samples, offset=(0, 0)):
    u, v = hammersley_sequence(2, n, num_samples)
    u += offset[0] / num_samples
    v += offset[1]
    u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
    theta = np.arccos(1 - 2 * u) - np.pi / 2
    phi = v * 2 * np.pi
    return [phi, theta]
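As a quick sanity check of the sequence code above, a stdlib-only sketch (mirroring `radical_inverse` and the 2-D case of `hammersley_sequence`, where the second coordinate is the base-2 radical inverse) shows the first few low-discrepancy points:

```python
def radical_inverse(base, n):
    # Mirror the digits of n in the given base about the radix point:
    # n = 6 = 110 in base 2  ->  0.011 in base 2  =  0.375
    val, inv_base_n = 0.0, 1.0 / base
    while n > 0:
        val += (n % base) * inv_base_n
        n //= base
        inv_base_n /= base
    return val

def hammersley_2d(n, num_samples):
    # First coordinate is the uniform fraction n / N, second the
    # base-2 radical inverse (PRIMES[0] == 2 in the file above).
    return (n / num_samples, radical_inverse(2, n))

print([hammersley_2d(n, 4) for n in range(4)])
# [(0.0, 0.0), (0.25, 0.5), (0.5, 0.25), (0.75, 0.75)]
```

`sphere_hammersley_sequence` then warps these 2-D points onto the sphere, with the `u < 0.25` branch biasing the elevation distribution away from the poles.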

# ==============PLY IO===============
import struct
import re
import torch


def read_ply(filename):
    """
    Read a PLY file and return vertices, triangle faces, and quad faces.

    Args:
        filename (str): The file path to read from.

    Returns:
        vertices (torch.Tensor): Tensor of shape [N, 3] containing vertex positions.
        tris (torch.Tensor): Tensor of shape [M, 3] containing triangle face indices (empty if none).
        quads (torch.Tensor): Tensor of shape [K, 4] containing quad face indices (empty if none).
    """
    with open(filename, 'rb') as f:
        # Read the header until 'end_header' is encountered
        header_bytes = b""
        while True:
            line = f.readline()
            if not line:
                raise ValueError("PLY header not found")
            header_bytes += line
            if b"end_header" in line:
                break
        header = header_bytes.decode('utf-8')

        # Determine if the file is in ASCII or binary format
        is_ascii = "ascii" in header

        # Extract the number of vertices and faces from the header using regex
        vertex_match = re.search(r'element vertex (\d+)', header)
        if vertex_match:
            num_vertices = int(vertex_match.group(1))
        else:
            raise ValueError("Vertex count not found in header")

        face_match = re.search(r'element face (\d+)', header)
        if face_match:
            num_faces = int(face_match.group(1))
        else:
            raise ValueError("Face count not found in header")

        vertices = []
        tris = []
        quads = []

        if is_ascii:
            # For ASCII format, read each line of vertex data (each line contains 3 floats)
            for _ in range(num_vertices):
                line = f.readline().decode('utf-8').strip()
                if not line:
                    continue
                parts = line.split()
                vertices.append([float(parts[0]), float(parts[1]), float(parts[2])])

            # Read face data, where the first number indicates the number of vertices for the face
            for _ in range(num_faces):
                line = f.readline().decode('utf-8').strip()
                if not line:
                    continue
                parts = line.split()
                count = int(parts[0])
                indices = list(map(int, parts[1:]))
                if count == 3:
                    tris.append(indices)
                elif count == 4:
                    quads.append(indices)
                else:
                    # Skip faces with other numbers of vertices (can be extended as needed)
                    pass
        else:
            # For binary format: read directly from the binary stream
            # Each vertex consists of 3 floats (12 bytes per vertex)
            for _ in range(num_vertices):
                data = f.read(12)
                if len(data) < 12:
                    raise ValueError("Insufficient vertex data")
                v = struct.unpack('<fff', data)
                vertices.append(v)

            # Read face data from the binary stream
            for _ in range(num_faces):
                # First, read 1 byte indicating the number of vertices in the face
                count_data = f.read(1)
                if len(count_data) < 1:
                    raise ValueError("Failed to read face vertex count")
                count = struct.unpack('<B', count_data)[0]
                if count == 3:
                    data = f.read(12)  # 3 * 4 bytes
                    if len(data) < 12:
                        raise ValueError("Insufficient data for triangle face")
                    indices = struct.unpack('<3i', data)
                    tris.append(indices)
                elif count == 4:
                    data = f.read(16)  # 4 * 4 bytes
                    if len(data) < 16:
                        raise ValueError("Insufficient data for quad face")
                    indices = struct.unpack('<4i', data)
                    quads.append(indices)
                else:
                    # For faces with a different number of vertices, read count*4 bytes
                    data = f.read(count * 4)
                    # Skip or extend processing as needed
                    raise ValueError(f"Unsupported face with {count} vertices")

    # Convert lists to torch.Tensor
    vertices = torch.tensor(vertices, dtype=torch.float32)
    tris = torch.tensor(tris, dtype=torch.int32) if len(tris) > 0 else torch.empty((0, 3), dtype=torch.int32)
    quads = torch.tensor(quads, dtype=torch.int32) if len(quads) > 0 else torch.empty((0, 4), dtype=torch.int32)

    return vertices, tris, quads


def write_ply(filename, vertices, tris, quads, ascii=False):
    """
    Write a mesh to a PLY file, with the option to save in ASCII or binary format.

    Args:
        filename (str): The filename to write to.
        vertices (torch.Tensor): [N, 3] The vertex positions.
        tris (torch.Tensor): [M, 3] The triangle indices.
        quads (torch.Tensor): [K, 4] The quad indices.
        ascii (bool): If True, write in ASCII format. If False, write in binary format.
    """
    # Convert torch tensors to numpy arrays
    vertices = vertices.numpy()
    tris = tris.numpy()
    quads = quads.numpy()

    # Prepare the header
    num_vertices = len(vertices)
    num_faces = len(tris) + len(quads)

    # Vertex properties
    vertex_header = "property float x\nproperty float y\nproperty float z"

    # Face properties (the number of vertices per face is variable)
    face_header = "property list uchar int vertex_index"

    # Start writing the PLY header
    header = "ply\n"
    header += f"format {'ascii 1.0' if ascii else 'binary_little_endian 1.0'}\n"
    header += f"element vertex {num_vertices}\n"
    header += vertex_header + "\n"
    header += f"element face {num_faces}\n"
    header += face_header + "\n"
    header += "end_header\n"

    # Open the file for writing
    with open(filename, 'wb' if not ascii else 'w') as f:
        # Write the header
        f.write(header if ascii else header.encode('utf-8'))

        # Write the vertex data
        if ascii:
            for v in vertices:
                f.write(f"{v[0]} {v[1]} {v[2]}\n")
        else:
            for v in vertices:
                f.write(struct.pack('<fff', *v))

        # Write the face data
        if ascii:
            for tri in tris:
                f.write(f"3 {tri[0]} {tri[1]} {tri[2]}\n")
            for quad in quads:
                f.write(f"4 {quad[0]} {quad[1]} {quad[2]} {quad[3]}\n")
        else:
            for tri in tris:
                f.write(struct.pack('<B3i', 3, *tri))  # 3 indices for triangle
            for quad in quads:
                f.write(struct.pack('<B4i', 4, *quad))  # 4 indices for quad
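A stdlib-only round-trip sketch of the same binary layout that `write_ply`/`read_ply` use (regex over the header, `struct` over the body), kept file-free via `io.BytesIO`:

```python
import io
import re
import struct

# Build a minimal binary_little_endian PLY in memory: 3 vertices, 1 triangle.
verts = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
header = (
    "ply\n"
    "format binary_little_endian 1.0\n"
    f"element vertex {len(verts)}\n"
    "property float x\nproperty float y\nproperty float z\n"
    "element face 1\n"
    "property list uchar int vertex_index\n"
    "end_header\n"
)
buf = io.BytesIO()
buf.write(header.encode('utf-8'))
for v in verts:
    buf.write(struct.pack('<fff', *v))           # 12 bytes per vertex
buf.write(struct.pack('<B3i', 3, 0, 1, 2))       # count byte + 3 int32 indices

# Parse it back the same way read_ply does.
data = buf.getvalue()
head, body = data.split(b'end_header\n', 1)
num_vertices = int(re.search(r'element vertex (\d+)', head.decode()).group(1))
parsed = [struct.unpack_from('<fff', body, 12 * i) for i in range(num_vertices)]
count = struct.unpack_from('<B', body, 12 * num_vertices)[0]
tri = struct.unpack_from('<3i', body, 12 * num_vertices + 1)
print(num_vertices, count, tri)  # 3 3 (0, 1, 2)
```

The one-byte face count followed by `count` int32 indices is exactly why the reader branches on `count == 3` vs `count == 4` before choosing the `struct` format string.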


# ==============IMAGE UTILS===============

def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
    num_images = len(images)
    if nrow is None and ncol is None:
        if aspect_ratio is not None:
            nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
        else:
            nrow = int(np.sqrt(num_images))
        ncol = (num_images + nrow - 1) // nrow
    elif nrow is None and ncol is not None:
        nrow = (num_images + ncol - 1) // ncol
    elif nrow is not None and ncol is None:
        ncol = (num_images + nrow - 1) // nrow
    else:
        assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'

    if images[0].ndim == 2:
        grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype)
    else:
        grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
    for i, img in enumerate(images):
        row = i // ncol
        col = i % ncol
        grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
    return grid
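A pure-Python sketch of `make_grid`'s row/column bookkeeping (the `aspect_ratio` branch is omitted for brevity): ceil division picks the grid shape, then image `i` lands at tile `(i // ncol, i % ncol)`.

```python
def grid_shape(num_images, nrow=None, ncol=None):
    # Mirrors make_grid's shape inference without the numpy copy.
    if nrow is None and ncol is None:
        nrow = int(num_images ** 0.5)
        ncol = (num_images + nrow - 1) // nrow  # ceil division, as above
    elif nrow is None:
        nrow = (num_images + ncol - 1) // ncol
    elif ncol is None:
        ncol = (num_images + nrow - 1) // nrow
    return nrow, ncol

print(grid_shape(10))          # (3, 4)
print(grid_shape(10, ncol=5))  # (2, 5)
```

The ceil division guarantees `nrow * ncol >= num_images`, so the final assert in `make_grid` only has to fire when the caller pins both dimensions too small.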


def notes_on_image(img, notes=None):
    img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    if notes is not None:
        img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img


def text_image(text, resolution=(512, 512), max_size=0.5, h_align="left", v_align="center"):
    """
    Draw text on an image of the given resolution. The text is automatically wrapped
    and scaled so that it fits completely within the image while preserving any explicit
    line breaks and original spacing. Horizontal and vertical alignment can be controlled
    via flags.

    Parameters:
        text (str): The input text. Newline characters and spacing are preserved.
        resolution (tuple): The image resolution as (width, height).
        max_size (float): The maximum font size.
        h_align (str): Horizontal alignment. Options: "left", "center", "right".
        v_align (str): Vertical alignment. Options: "top", "center", "bottom".

    Returns:
        numpy.ndarray: The resulting image (BGR format) with the text drawn.
    """
    width, height = resolution
    # Create a white background image
    img = np.full((height, width, 3), 255, dtype=np.uint8)

    # Set margins and compute available drawing area
    margin = 10
    avail_width = width - 2 * margin
    avail_height = height - 2 * margin

    # Choose OpenCV font and text thickness
    font = cv2.FONT_HERSHEY_SIMPLEX
    thickness = 1
    # Ratio for additional spacing between lines (relative to the height of "A")
    line_spacing_ratio = 0.5

    def wrap_line(line, max_width, font, thickness, scale):
        """
        Wrap a single line of text into multiple lines such that each line's
        width (measured at the given scale) does not exceed max_width.
        This function preserves the original spacing by splitting the line into tokens
        (words and whitespace) using a regular expression.

        Parameters:
            line (str): The input text line.
            max_width (int): Maximum allowed width in pixels.
            font (int): OpenCV font identifier.
            thickness (int): Text thickness.
            scale (float): The current font scale.

        Returns:
            List[str]: A list of wrapped lines.
        """
        # Split the line into tokens (words and whitespace), preserving spacing
        tokens = re.split(r'(\s+)', line)
        if not tokens:
            return ['']

        wrapped_lines = []
        current_line = ""
        for token in tokens:
            candidate = current_line + token
            candidate_width = cv2.getTextSize(candidate, font, scale, thickness)[0][0]
            if candidate_width <= max_width:
                current_line = candidate
            else:
                # If current_line is empty, the token itself is too wide;
                # break the token character by character.
                if current_line == "":
                    sub_token = ""
                    for char in token:
                        candidate_char = sub_token + char
                        if cv2.getTextSize(candidate_char, font, scale, thickness)[0][0] <= max_width:
                            sub_token = candidate_char
                        else:
                            if sub_token:
                                wrapped_lines.append(sub_token)
                            sub_token = char
                    current_line = sub_token
                else:
                    wrapped_lines.append(current_line)
                    current_line = token
        if current_line:
            wrapped_lines.append(current_line)
        return wrapped_lines

    def compute_text_block(scale):
        """
        Wrap the entire text (splitting at explicit newline characters) using the
        provided scale, and then compute the overall width and height of the text block.

        Returns:
            wrapped_lines (List[str]): The list of wrapped lines.
            block_width (int): Maximum width among the wrapped lines.
            block_height (int): Total height of the text block including spacing.
            sizes (List[tuple]): A list of (width, height) for each wrapped line.
            spacing (int): The spacing between lines (computed from the scaled "A" height).
        """
        # Split text by explicit newlines
        input_lines = text.splitlines() if text else ['']
        wrapped_lines = []
        for line in input_lines:
            wrapped = wrap_line(line, avail_width, font, thickness, scale)
            wrapped_lines.extend(wrapped)

        sizes = []
        for line in wrapped_lines:
            (text_size, _) = cv2.getTextSize(line, font, scale, thickness)
            sizes.append(text_size)  # (width, height)

        block_width = max((w for w, h in sizes), default=0)
        # Use the height of "A" (at the current scale) to compute line spacing
        base_height = cv2.getTextSize("A", font, scale, thickness)[0][1]
        spacing = int(line_spacing_ratio * base_height)
        block_height = sum(h for w, h in sizes) + spacing * (len(sizes) - 1) if sizes else 0

        return wrapped_lines, block_width, block_height, sizes, spacing

    # Use binary search to find the maximum scale that allows the text block to fit
    lo = 0.001
    hi = max_size
    eps = 0.001  # convergence threshold
    best_scale = lo
    best_result = None

    while hi - lo > eps:
        mid = (lo + hi) / 2
        wrapped_lines, block_width, block_height, sizes, spacing = compute_text_block(mid)
        # Ensure that both width and height constraints are met
        if block_width <= avail_width and block_height <= avail_height:
            best_scale = mid
            best_result = (wrapped_lines, block_width, block_height, sizes, spacing)
            lo = mid  # try a larger scale
        else:
            hi = mid  # reduce the scale

    if best_result is None:
        best_scale = 0.5
        best_result = compute_text_block(best_scale)

    wrapped_lines, block_width, block_height, sizes, spacing = best_result

    # Compute starting y-coordinate based on vertical alignment flag
    if v_align == "top":
        y_top = margin
    elif v_align == "center":
        y_top = margin + (avail_height - block_height) // 2
    elif v_align == "bottom":
        y_top = margin + (avail_height - block_height)
    else:
        y_top = margin + (avail_height - block_height) // 2  # default to center if invalid flag

    # For cv2.putText, the y coordinate represents the text baseline;
    # so for the first line add its height.
    y = y_top + (sizes[0][1] if sizes else 0)

    # Draw each line with horizontal alignment based on the flag
    for i, line in enumerate(wrapped_lines):
        line_width, line_height = sizes[i]
        if h_align == "left":
            x = margin
        elif h_align == "center":
            x = margin + (avail_width - line_width) // 2
        elif h_align == "right":
            x = margin + (avail_width - line_width)
        else:
            x = margin  # default to left if invalid flag

        cv2.putText(img, line, (x, y), font, best_scale, (0, 0, 0), thickness, cv2.LINE_AA)
        y += line_height + spacing

    return img
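The scale search in `text_image` is a generic pattern: binary-search the largest value for which a monotone feasibility predicate still holds. A standalone sketch with a toy stand-in for `compute_text_block`:

```python
def max_feasible(fits, lo=0.001, hi=0.5, eps=0.001):
    # Largest value in [lo, hi] (to within eps) for which fits() is True,
    # assuming fits is monotone: True below some threshold, False above it.
    best = lo
    while hi - lo > eps:
        mid = (lo + hi) / 2
        if fits(mid):
            best, lo = mid, mid   # feasible: try a larger scale
        else:
            hi = mid              # too big: shrink
    return best

# Toy stand-in: text of width 1000*scale px must fit into 300 px.
scale = max_feasible(lambda s: 1000 * s <= 300)
print(round(scale, 2))  # 0.3
```

Monotonicity is what makes this valid here: rendering the same text at a larger scale can only grow the wrapped block, so feasibility flips from True to False exactly once.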


def save_image_with_notes(img, path, notes=None):
    """
    Save an image with notes.
    """
    if isinstance(img, torch.Tensor):
        img = img.cpu().numpy().transpose(1, 2, 0)
    if img.dtype == np.float32 or img.dtype == np.float64:
        img = np.clip(img * 255, 0, 255).astype(np.uint8)
    img = notes_on_image(img, notes)
    cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

data_toolkit/voxelize_pbr.py (new executable file, 167 lines)
@@ -0,0 +1,167 @@
import os
import copy
import sys
import importlib
import argparse
import pandas as pd
import pickle
import numpy as np
import torch
from easydict import EasyDict as edict
from functools import partial
import o_voxel


def _pbr_voxelize(file, metadatum, pbr_dump_root, root):
    sha256 = metadatum['sha256']
    try:
        pack = {'sha256': sha256}
        dump = None
        for res in opt.resolution:
            need_process = False

            # check if already processed
            if os.path.exists(os.path.join(root, f'pbr_voxels_{res}', f'{sha256}.vxz')):
                try:
                    info = o_voxel.io.read_vxz_info(os.path.join(root, f'pbr_voxels_{res}', f'{sha256}.vxz'))
                    pack[f'pbr_voxelized_{res}'] = True
                    pack[f'num_pbr_voxels_{res}'] = info['num_voxel']
                except Exception as e:
                    print(f'Error reading {sha256}.vxz: {e}')
                    need_process = True
            else:
                need_process = True

            # process if necessary
            if need_process:
                if dump is None:
                    with open(os.path.join(pbr_dump_root, 'pbr_dumps', f'{sha256}.pickle'), 'rb') as f:
                        dump = pickle.load(f)
                    # Fix dump alpha map
                    for mat in dump['materials']:
                        if mat['alphaTexture'] is not None and mat['alphaMode'] == 'OPAQUE':
                            mat['alphaMode'] = 'BLEND'
                    dump['materials'].append({
                        "baseColorFactor": [0.8, 0.8, 0.8],
                        "alphaFactor": 1.0,
                        "metallicFactor": 0.0,
                        "roughnessFactor": 0.5,
                        "alphaMode": "OPAQUE",
                        "alphaCutoff": 0.5,
                        "baseColorTexture": None,
                        "alphaTexture": None,
                        "metallicTexture": None,
                        "roughnessTexture": None,
                    })  # append default material
                    dump['objects'] = [
                        obj for obj in dump['objects']
                        if obj['vertices'].size != 0 and obj['faces'].size != 0
                    ]
                    vertices = torch.from_numpy(np.concatenate([obj['vertices'] for obj in dump['objects']], axis=0)).float()
                    vertices_min = vertices.min(dim=0)[0]
                    vertices_max = vertices.max(dim=0)[0]
                    center = (vertices_min + vertices_max) / 2
                    scale = 0.99999 / (vertices_max - vertices_min).max()
                    for obj in dump['objects']:
                        obj['vertices'] = (torch.from_numpy(obj['vertices']).float() - center) * scale
                        obj['vertices'] = obj['vertices'].numpy()
                        obj['mat_ids'][obj['mat_ids'] == -1] = len(dump['materials']) - 1
                        assert np.all(obj['mat_ids'] >= 0), 'invalid mat_ids'
                        assert np.all(obj['vertices'] >= -0.5) and np.all(obj['vertices'] <= 0.5), 'vertices out of range'

                coord, attr = o_voxel.convert.blender_dump_to_volumetric_attr(dump, grid_size=res, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
                                                                              mip_level_offset=0, verbose=False, timing=False)
                del attr['normal']
                del attr['emissive']
                o_voxel.io.write_vxz(os.path.join(root, f'pbr_voxels_{res}', f'{sha256}.vxz'), coord, attr)
                pack[f'pbr_voxelized_{res}'] = True
                pack[f'num_pbr_voxels_{res}'] = len(coord)

        return pack
    except Exception as e:
        print(f'Error voxelizing {sha256}: {e}')
        return {'sha256': sha256, 'error': str(e)}


if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--pbr_dump_root', type=str, default=None,
                        help='Directory to load mesh dumps')
    parser.add_argument('--pbr_voxel_root', type=str, default=None,
                        help='Directory to save voxelized pbr attributes')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    dataset_utils.add_args(parser)
    parser.add_argument('--resolution', type=str, default='1024')
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=0)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))
    opt.resolution = sorted([int(x) for x in opt.resolution.split(',')], reverse=True)
    opt.pbr_dump_root = opt.pbr_dump_root or opt.root
    opt.pbr_voxel_root = opt.pbr_voxel_root or opt.root

    for res in opt.resolution:
        os.makedirs(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{res}', 'new_records'), exist_ok=True)

    # get file list
    if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
    if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')).set_index('sha256'))
    if os.path.exists(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')):
        metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')).set_index('sha256'))
    for res in opt.resolution:
        if os.path.exists(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{res}', 'metadata.csv')):
            pbr_voxel_metadata = pd.read_csv(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{res}', 'metadata.csv')).set_index('sha256')
            pbr_voxel_metadata = pbr_voxel_metadata.rename(columns={'pbr_voxelized': f'pbr_voxelized_{res}', 'num_pbr_voxels': f'num_pbr_voxels_{res}'})
            metadata = metadata.combine_first(pbr_voxel_metadata)
    metadata = metadata.reset_index()
    if opt.instances is None:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        metadata = metadata[metadata['pbr_dumped'] == True]
        mask = np.zeros(len(metadata), dtype=bool)
        for res in opt.resolution:
            if f'pbr_voxelized_{res}' in metadata.columns:
                mask |= metadata[f'pbr_voxelized_{res}'] != True
            else:
                mask[:] = True
                break
        metadata = metadata[mask]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    func = partial(_pbr_voxelize, pbr_dump_root=opt.pbr_dump_root, root=opt.pbr_voxel_root)
    pbr_voxelized = dataset_utils.foreach_instance(metadata, None, func, max_workers=opt.max_workers, no_file=True, desc='Voxelizing')
    if 'error' in pbr_voxelized.columns:
        errors = pbr_voxelized[pbr_voxelized['error'].notna()]
        with open('errors.txt', 'w') as f:
            f.write('\n'.join(errors['sha256'].tolist()))
    for res in opt.resolution:
        if f'pbr_voxelized_{res}' in pbr_voxelized.columns:
            pbr_voxel_metadata = pbr_voxelized[pbr_voxelized[f'pbr_voxelized_{res}'] == True]
            if len(pbr_voxel_metadata) > 0:
                pbr_voxel_metadata = pbr_voxel_metadata[['sha256', f'pbr_voxelized_{res}', f'num_pbr_voxels_{res}']]
                pbr_voxel_metadata = pbr_voxel_metadata.rename(columns={f'pbr_voxelized_{res}': 'pbr_voxelized', f'num_pbr_voxels_{res}': 'num_pbr_voxels'})
                pbr_voxel_metadata.to_csv(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{res}', 'new_records', f'part_{opt.rank}.csv'), index=False)

train.py (new file, 158 lines)
@@ -0,0 +1,158 @@
import os
import sys
import json
import glob
import argparse
from easydict import EasyDict as edict

import torch
import torch.multiprocessing as mp
import numpy as np
import random

from trellis2 import models, datasets, trainers
from trellis2.utils.dist_utils import setup_dist


def find_ckpt(cfg):
    # Load checkpoint
    cfg['load_ckpt'] = None
    if cfg.load_dir != '':
        if cfg.ckpt == 'latest':
            files = glob.glob(os.path.join(cfg.load_dir, 'ckpts', 'misc_*.pt'))
            if len(files) != 0:
                cfg.load_ckpt = max([
                    int(os.path.basename(f).split('step')[-1].split('.')[0])
                    for f in files
                ])
        elif cfg.ckpt == 'none':
            cfg.load_ckpt = None
        else:
            cfg.load_ckpt = int(cfg.ckpt)
    return cfg
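The `latest` branch of `find_ckpt` reduces to extracting the step number embedded in each checkpoint filename and taking the maximum. A sketch with hypothetical filenames (the exact `misc_step*.pt` naming is assumed from the glob above):

```python
import os

# Hypothetical checkpoint paths, as the glob 'misc_*.pt' might return them.
files = [
    'ckpts/misc_step0005000.pt',
    'ckpts/misc_step0010000.pt',
    'ckpts/misc_step0002500.pt',
]

# 'misc_step0010000.pt' -> split('step')[-1] -> '0010000.pt'
#                       -> split('.')[0]     -> '0010000' -> 10000
latest = max(int(os.path.basename(f).split('step')[-1].split('.')[0]) for f in files)
print(latest)  # 10000
```

Comparing the parsed integers rather than the strings is what makes this robust to zero-padding and to glob's unsorted ordering.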
||||
|
||||
|
||||
def setup_rng(rank):
|
||||
torch.manual_seed(rank)
|
||||
torch.cuda.manual_seed_all(rank)
|
||||
np.random.seed(rank)
|
||||
random.seed(rank)
|
||||
|
||||
|
||||
def get_model_summary(model):
|
||||
model_summary = 'Parameters:\n'
|
||||
model_summary += '=' * 128 + '\n'
|
||||
model_summary += f'{"Name":<{72}}{"Shape":<{32}}{"Type":<{16}}{"Grad"}\n'
|
||||
num_params = 0
|
||||
num_trainable_params = 0
|
||||
for name, param in model.named_parameters():
|
||||
model_summary += f'{name:<{72}}{str(param.shape):<{32}}{str(param.dtype):<{16}}{param.requires_grad}\n'
|
||||
num_params += param.numel()
|
||||
if param.requires_grad:
|
||||
num_trainable_params += param.numel()
|
||||
model_summary += '\n'
|
||||
model_summary += f'Number of parameters: {num_params}\n'
|
||||
model_summary += f'Number of trainable parameters: {num_trainable_params}\n'
|
||||
return model_summary
|
||||
|
||||
|
||||
def main(local_rank, cfg):
|
||||
# Set up distributed training
|
||||
rank = cfg.node_rank * cfg.num_gpus + local_rank
|
||||
world_size = cfg.num_nodes * cfg.num_gpus
|
||||
if world_size > 1:
|
||||
setup_dist(rank, local_rank, world_size, cfg.master_addr, cfg.master_port)
|
||||
|
||||
# Seed rngs
|
||||
setup_rng(rank)
|
||||
|
||||
# Load data
|
||||
dataset = getattr(datasets, cfg.dataset.name)(cfg.data_dir, **cfg.dataset.args)
|
||||
|
||||
# Build model
|
||||
model_dict = {
|
||||
name: getattr(models, model.name)(**model.args).cuda()
|
||||
for name, model in cfg.models.items()
|
||||
}
|
||||
|
||||
# Model summary
|
||||
if rank == 0:
|
||||
for name, backbone in model_dict.items():
|
||||
model_summary = get_model_summary(backbone)
|
||||
print(f'\n\nBackbone: {name}\n' + model_summary)
|
||||
with open(os.path.join(cfg.output_dir, f'{name}_model_summary.txt'), 'w') as fp:
|
||||
print(model_summary, file=fp)
|
||||
|
||||
# Build trainer
|
||||
trainer = getattr(trainers, cfg.trainer.name)(model_dict, dataset, **cfg.trainer.args, output_dir=cfg.output_dir, load_dir=cfg.load_dir, step=cfg.load_ckpt)
|
||||
|
||||
# Train
|
||||
if not cfg.tryrun:
|
||||
if cfg.profile:
|
||||
trainer.profile()
|
||||
else:
|
||||
trainer.run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Arguments and config
|
||||
parser = argparse.ArgumentParser()
|
||||
## config
|
||||
parser.add_argument('--config', type=str, required=True, help='Experiment config file')
|
||||
## io and resume
|
||||
parser.add_argument('--output_dir', type=str, required=True, help='Output directory')
|
||||
parser.add_argument('--load_dir', type=str, default='', help='Load directory, default to output_dir')
|
||||
parser.add_argument('--ckpt', type=str, default='latest', help='Checkpoint step to resume training, default to latest')
|
||||
parser.add_argument('--data_dir', type=str, default='./data/', help='Data directory')
|
||||
parser.add_argument('--auto_retry', type=int, default=3, help='Number of retries on error')
|
||||
## dubug
|
||||
parser.add_argument('--tryrun', action='store_true', help='Try run without training')
|
||||
parser.add_argument('--profile', action='store_true', help='Profile training')
|
||||
## multi-node and multi-gpu
|
||||
parser.add_argument('--num_nodes', type=int, default=1, help='Number of nodes')
|
||||
parser.add_argument('--node_rank', type=int, default=0, help='Node rank')
|
||||
parser.add_argument('--num_gpus', type=int, default=-1, help='Number of GPUs per node, default to all')
|
||||
parser.add_argument('--master_addr', type=str, default='localhost', help='Master address for distributed training')
|
||||
parser.add_argument('--master_port', type=str, default='12345', help='Port for distributed training')
|
||||
opt = parser.parse_args()
|
||||
opt.load_dir = opt.load_dir if opt.load_dir != '' else opt.output_dir
|
||||
opt.num_gpus = torch.cuda.device_count() if opt.num_gpus == -1 else opt.num_gpus
|
||||
## Load config
|
||||
config = json.load(open(opt.config, 'r'))
|
||||
## Combine arguments and config
|
||||
cfg = edict()
|
||||
cfg.update(opt.__dict__)
|
||||
cfg.update(config)
|
||||
print('\n\nConfig:')
|
||||
print('=' * 80)
|
||||
print(json.dumps(cfg.__dict__, indent=4))
|
||||
|
||||
# Prepare output directory
|
||||
if cfg.node_rank == 0:
|
||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||
## Save command and config
|
||||
with open(os.path.join(cfg.output_dir, 'command.txt'), 'w') as fp:
|
||||
print(' '.join(['python'] + sys.argv), file=fp)
|
||||
with open(os.path.join(cfg.output_dir, 'config.json'), 'w') as fp:
|
||||
json.dump(config, fp, indent=4)
|
||||
|
||||
# Run
if cfg.auto_retry == 0:
    cfg = find_ckpt(cfg)
    if cfg.num_gpus > 1:
        mp.spawn(main, args=(cfg,), nprocs=cfg.num_gpus, join=True)
    else:
        main(0, cfg)
else:
    for rty in range(cfg.auto_retry):
        try:
            cfg = find_ckpt(cfg)
            if cfg.num_gpus > 1:
                mp.spawn(main, args=(cfg,), nprocs=cfg.num_gpus, join=True)
            else:
                main(0, cfg)
            break
        except Exception as e:
            print(f'Error: {e}')
            print(f'Retrying ({rty + 1}/{cfg.auto_retry})...')
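The `--auto_retry` path above wraps the whole launch in a retry loop, so a transient failure restarts training from the latest checkpoint located by `find_ckpt`. A self-contained sketch of just the retry policy, with a stand-in for the launch function:

```python
def run_with_retry(launch, cfg, auto_retry):
    """Sketch of the launcher's retry policy: run once when auto_retry is 0,
    otherwise relaunch until success or the retry budget is spent."""
    if auto_retry == 0:
        return launch(cfg)
    for rty in range(auto_retry):
        try:
            return launch(cfg)
        except Exception as e:
            print(f'Error: {e}')
            print(f'Retrying ({rty + 1}/{auto_retry})...')

attempts = []
def flaky(cfg):
    # Stand-in launch function: fails twice, then succeeds
    attempts.append(cfg)
    if len(attempts) < 3:
        raise RuntimeError('simulated crash')
    return 'done'

print(run_with_retry(flaky, {}, 5))  # -> done (on the third attempt)
```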
@@ -10,10 +10,10 @@ import utils3d
 from .components import StandardDatasetBase
 from ..modules import sparse as sp
 from ..renderers import VoxelRenderer
-from ..representations.mesh import Voxel, MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture
+from ..representations import Voxel
+from ..representations.mesh import MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture
 from ..utils.data_utils import load_balanced_group_indices
 from ..utils.mesh_utils import subdivide_to_size


 def is_power_of_two(n: int) -> bool:
@@ -71,7 +71,7 @@ class SparseVoxelPbrVisMixin:
                 origin=[-0.5, -0.5, -0.5],
                 voxel_size=1/self.resolution,
                 coords=x[i].coords[:, 1:].contiguous(),
-                attrs=attr,
+                attrs=None,
                 layout={
                     'color': slice(0, 3),
                 }
@@ -81,7 +81,7 @@ class SparseVoxelPbrVisMixin:
             tile = [2, 2]
             for j, (ext, intr) in enumerate(zip(exts, ints)):
                 attr = x[i].feats[:, self.layout[k]].expand(-1, 3)
-                res = renderer.render(rep, ext, intr)
+                res = renderer.render(rep, ext, intr, colors_overwrite=attr)
                 image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
             images[k].append(image)
@@ -16,12 +16,12 @@ class SLatShapeVisMixin(SLatVisMixin):
             return
         if self.slat_dec_path is not None:
             cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r'))
+            cfg['models']['decoder']['args']['resolution'] = self.resolution
             decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
             ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
             decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
         else:
             decoder = models.from_pretrained(self.pretrained_slat_dec)
+            decoder.set_resolution(self.resolution)
         self.slat_dec = decoder.cuda().eval()

     @torch.no_grad()
@@ -72,7 +72,7 @@ class SLatShape(SLatShapeVisMixin, SLat):
         min_aesthetic_score: float = 5.0,
         max_tokens: int = 32768,
         normalization: Optional[dict] = None,
-        pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16',
+        pretrained_slat_dec: str = 'microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16',
         slat_dec_path: Optional[str] = None,
         slat_dec_ckpt: Optional[str] = None,
     ):
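The default pretrained decoder now points at the `microsoft/TRELLIS.2-4B` Hugging Face repo instead of the author's personal namespace. Assuming the usual TRELLIS convention that the first two path segments name the repo and the remainder is the checkpoint path inside it (an assumption from the path shape, not confirmed by this diff, and `split_pretrained_path` is a hypothetical helper), the string decomposes as:

```python
def split_pretrained_path(path):
    # Assumed convention (illustrative): '<org>/<repo>/<path inside repo>'
    parts = path.split('/')
    return '/'.join(parts[:2]), '/'.join(parts[2:])

repo, subpath = split_pretrained_path('microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16')
print(repo)     # -> microsoft/TRELLIS.2-4B
print(subpath)  # -> ckpts/shape_dec_next_dc_f16c32_fp16
```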
@@ -1,14 +1,17 @@
 import os
+os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
 import json
 from typing import *
 import numpy as np
 import torch
+import cv2
 from .. import models
 from .components import StandardDatasetBase, ImageConditionedMixin
 from ..modules.sparse import SparseTensor, sparse_cat
 from ..representations import MeshWithVoxel
+from ..renderers import PbrMeshRenderer, EnvMap
 from ..utils.data_utils import load_balanced_group_indices
-from ..utils.render_utils import get_renderer, yaw_pitch_r_fov_to_extrinsics_intrinsics
+from ..utils.render_utils import yaw_pitch_r_fov_to_extrinsics_intrinsics


 class SLatPbrVisMixin:
@@ -47,12 +50,12 @@ class SLatPbrVisMixin:

         if self.shape_slat_dec_path is not None:
             cfg = json.load(open(os.path.join(self.shape_slat_dec_path, 'config.json'), 'r'))
+            cfg['models']['decoder']['args']['resolution'] = self.resolution
             decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
             ckpt_path = os.path.join(self.shape_slat_dec_path, 'ckpts', f'decoder_{self.shape_slat_dec_ckpt}.pt')
             decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
         else:
             decoder = models.from_pretrained(self.pretrained_shape_slat_dec)
+            decoder.set_resolution(self.resolution)
         self.shape_slat_dec = decoder.cuda().eval()

     def _delete_slat_dec(self):
@@ -71,7 +74,7 @@ class SLatPbrVisMixin:
         z = z * self.pbr_slat_std.to(z.device) + self.pbr_slat_mean.to(z.device)
         for i in range(0, z.shape[0], batch_size):
             mesh, subs = self.shape_slat_dec(shape_z[i:i+batch_size], return_subs=True)
-            vox = self.pbr_slat_dec(z[i:i+batch_size], guide_subs=subs)
+            vox = self.pbr_slat_dec(z[i:i+batch_size], guide_subs=subs) * 0.5 + 0.5
             reps.extend([
                 MeshWithVoxel(
                     m.vertices, m.faces,
@@ -101,18 +104,32 @@ class SLatPbrVisMixin:
         exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30)

         # render
-        renderer = get_renderer(reps[0])
-        images = {k: [] for k in self.layout}
+        renderer = PbrMeshRenderer()
+        renderer.rendering_options.resolution = 512
+        renderer.rendering_options.near = 1
+        renderer.rendering_options.far = 100
+        renderer.rendering_options.ssaa = 2
+        renderer.rendering_options.peel_layers = 8
+        envmap = EnvMap(torch.tensor(
+            cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
+            dtype=torch.float32, device='cuda'
+        ))
+
+        images = {}
         for representation in reps:
-            image = {k: torch.zeros(3, 1024, 1024).cuda() for k in self.layout}
+            image = {}
             tile = [2, 2]
             for j, (ext, intr) in enumerate(zip(exts, ints)):
-                res = renderer.render(representation, ext, intr, return_types=['attr'])
-                for k in self.layout:
-                    image[k][:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res[k]
-            for k in self.layout:
+                res = renderer.render(representation, ext, intr, envmap=envmap)
+                for k, v in res.items():
+                    if k not in images:
+                        images[k] = []
+                    if k not in image:
+                        image[k] = torch.zeros(3, 1024, 1024).cuda()
+                    image[k][:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = v
+            for k in images.keys():
                 images[k].append(image[k])
-        for k in self.layout:
+        for k in images.keys():
            images[k] = torch.stack(images[k], dim=0)
        return images
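The rendering loop above packs four 512×512 views into one 1024×1024 grid, with `j // tile[1]` selecting the row and `j % tile[1]` the column. The index arithmetic in isolation:

```python
def tile_offset(j, tile=(2, 2), size=512):
    # Row-major placement of view j into a tile[0] x tile[1] grid of size-px cells
    return size * (j // tile[1]), size * (j % tile[1])

# Four views land at the four quadrants of a 1024x1024 canvas
print([tile_offset(j) for j in range(4)])  # -> [(0, 0), (0, 512), (512, 0), (512, 512)]
```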
@@ -156,7 +173,7 @@ class SLatPbr(SLatPbrVisMixin, StandardDatasetBase):
         self.min_aesthetic_score = min_aesthetic_score
         self.max_tokens = max_tokens
         self.full_pbr = full_pbr
-        self.value_range = (-1, 1)
+        self.value_range = (0, 1)

         super().__init__(
             roots,
@@ -87,8 +87,8 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
             vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin)
             intersected_logits = h.replace(h.feats[..., 3:6])
             quad_lerp = h.replace(F.softplus(h.feats[..., 6:7]))
-            mesh = [Mesh(flexible_dual_grid_to_mesh(
-                h.coords[:, 1:], v.feats, i.feats, q.feats,
+            mesh = [Mesh(*flexible_dual_grid_to_mesh(
+                v.coords[:, 1:], v.feats, i.feats, q.feats,
                 aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
                 grid_size=self.resolution,
                 train=True
@@ -101,7 +101,7 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
             intersected = h.replace(h.feats[..., 3:6] > 0)
             quad_lerp = h.replace(F.softplus(h.feats[..., 6:7]))
             mesh = [Mesh(*flexible_dual_grid_to_mesh(
-                h.coords[:, 1:], v.feats, i.feats, q.feats,
+                v.coords[:, 1:], v.feats, i.feats, q.feats,
                 aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
                 grid_size=self.resolution,
                 train=False
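The fix above adds a `*` so the tuple returned by `flexible_dual_grid_to_mesh` is unpacked into separate `Mesh` constructor arguments rather than passed as a single positional. A minimal illustration with hypothetical stand-ins (the real classes take more arguments):

```python
class Mesh:
    # Hypothetical stand-in: the real Mesh takes vertices and faces as
    # separate positional parameters
    def __init__(self, vertices, faces):
        self.vertices, self.faces = vertices, faces

def dual_grid_to_mesh():
    # Stand-in returning a (vertices, faces) tuple, like the real function
    return [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)], [(0, 1)]

mesh = Mesh(*dual_grid_to_mesh())  # '*' unpacks the tuple into both parameters
print(len(mesh.vertices), len(mesh.faces))  # -> 2 1
```

Without the `*`, the whole tuple would land in `vertices` and `faces` would be missing, raising a `TypeError`.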
@@ -250,9 +250,19 @@ class PbrMeshRenderer:
         ssaa = self.rendering_options["ssaa"]

         if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
-            return edict(
-                shaded=torch.full((4, resolution, resolution), 0.5, dtype=torch.float32, device=self.device),
+            out_dict = edict(
+                normal=torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device),
+                mask=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
+                base_color=torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device),
+                metallic=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
+                roughness=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
+                alpha=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
+                clay=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
             )
+            for i, k in enumerate(envmap.keys()):
+                shaded_key = f"shaded_{k}" if k != '' else "shaded"
+                out_dict[shaded_key] = torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device)
+            return out_dict

         rays_o, rays_d = utils3d.torch.get_image_rays(
             extrinsics, intrinsics, resolution * ssaa, resolution * ssaa
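The empty-mesh fallback above mirrors the renderer's output naming: an environment map with an empty name yields a plain `shaded` buffer, while named maps get `shaded_<name>`. That convention in isolation:

```python
def shaded_key(envmap_name):
    # Key naming used for per-envmap shaded outputs in the fallback above
    return f"shaded_{envmap_name}" if envmap_name != '' else "shaded"

print(shaded_key(''), shaded_key('forest'))  # -> shaded shaded_forest
```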
@@ -138,6 +138,7 @@ class ImageConditionedMixin:
         """
+        with dist_utils.local_master_first():
             self.image_cond_model = globals()[self.image_cond_model_config['name']](**self.image_cond_model_config.get('args', {}))
         self.image_cond_model.cuda()

     @torch.no_grad()
     def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: