diff --git a/README.md b/README.md index a83ae5dbaa0e..c1ae855ae12f 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,7 @@ In Computer Vision: - [Panoptic Segmentation with MaskFormer](https://huggingface.co/facebook/maskformer-swin-small-coco) - [Depth Estimation with DPT](https://huggingface.co/docs/transformers/model_doc/dpt) - [Video Classification with VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae) +- [Universal Segmentation with OneFormer](https://huggingface.co/shi-labs/oneformer_ade20k_dinat_large) In Audio: - [Automatic Speech Recognition with Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base-960h) @@ -371,6 +372,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu. 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team. 1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh. +1. **[OneFormer](https://huggingface.co/docs/transformers/main/model_doc/oneformer)** (from SHI Labs) released with the paper [OneFormer: One Transformer to Rule Universal Image Segmentation](https://arxiv.org/abs/2211.06220) by Jitesh Jain, Jiachen Li, MangTik Chiu, Ali Hassani, Nikita Orlov, Humphrey Shi. 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. diff --git a/README_es.md b/README_es.md index c66f4b77b2e4..d761f7846b20 100644 --- a/README_es.md +++ b/README_es.md @@ -92,6 +92,7 @@ En visión de ordenador: - [Detección de objetos con DETR](https://huggingface.co/facebook/detr-resnet-50) - [Segmentación semántica con SegFormer](https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512) - [Segmentación panóptica con DETR](https://huggingface.co/facebook/detr-resnet-50-panoptic) +- [Segmentación Universal con OneFormer (Segmentación Semántica, de Instancia y Panóptica con un solo modelo)](https://huggingface.co/shi-labs/oneformer_ade20k_dinat_large) En Audio: - [Reconocimiento de voz automático con Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base-960h) @@ -364,6 +365,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu. 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team. 1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh. +1. **[OneFormer](https://huggingface.co/docs/transformers/main/model_doc/oneformer)** (from SHI Labs) released with the paper [OneFormer: One Transformer to Rule Universal Image Segmentation](https://arxiv.org/abs/2211.06220) by Jitesh Jain, Jiachen Li, MangTik Chiu, Ali Hassani, Nikita Orlov, Humphrey Shi. 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. diff --git a/README_hd.md b/README_hd.md index 31f6b448cdde..cd2ef6d8512e 100644 --- a/README_hd.md +++ b/README_hd.md @@ -337,6 +337,7 @@ conda install -c huggingface transformers 1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (हुआवेई नूह के आर्क लैब से) साथ में कागज़ [NEZHA: चीनी भाषा समझ के लिए तंत्रिका प्रासंगिक प्रतिनिधित्व](https :/ /arxiv.org/abs/1909.00204) जुन्किउ वेई, ज़ियाओज़े रेन, ज़िआओगुआंग ली, वेनयोंग हुआंग, यी लियाओ, याशेंग वांग, जियाशू लिन, शिन जियांग, जिओ चेन और कुन लियू द्वारा। 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (फ्रॉम मेटा) साथ में पेपर [नो लैंग्वेज लेफ्ट बिहाइंड: स्केलिंग ह्यूमन-सेंटेड मशीन ट्रांसलेशन] (https://arxiv.org/abs/2207.04672) एनएलएलबी टीम द्वारा प्रकाशित। 1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (विस्कॉन्सिन विश्वविद्यालय - मैडिसन से) साथ में कागज [Nyströmformer: A Nyström- आधारित एल्गोरिथम आत्म-ध्यान का अनुमान लगाने के लिए ](https://arxiv.org/abs/2102.03902) युनयांग ज़िओंग, झानपेंग ज़ेंग, रुद्रसिस चक्रवर्ती, मिंगक्सिंग टैन, ग्लेन फंग, यिन ली, विकास सिंह द्वारा पोस्ट किया गया। +1. **[OneFormer](https://huggingface.co/docs/transformers/main/model_doc/oneformer)** (SHI Labs से) पेपर [OneFormer: One Transformer to Rule Universal Image Segmentation](https://arxiv.org/abs/2211.06220) जितेश जैन, जिआचेन ली, मांगटिक चिउ, अली हसनी, निकिता ओरलोव, हम्फ्री शि के द्वारा जारी किया गया है। 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (Google AI से) साथ में कागज [विज़न ट्रांसफॉर्मर्स के साथ सिंपल ओपन-वोकैबुलरी ऑब्जेक्ट डिटेक्शन](https:/ /arxiv.org/abs/2205.06230) मैथियास मिंडरर, एलेक्सी ग्रिट्सेंको, ऑस्टिन स्टोन, मैक्सिम न्यूमैन, डिर्क वीसेनबोर्न, एलेक्सी डोसोवित्स्की, अरविंद महेंद्रन, अनुराग अर्नब, मुस्तफा देहघानी, ज़ुओरन शेन, जिओ वांग, ज़ियाओहुआ झाई, थॉमस किफ़, और नील हॉल्सबी द्वारा पोस्ट किया गया। 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. diff --git a/README_ja.md b/README_ja.md index 9120113dc31e..2213cb09f85e 100644 --- a/README_ja.md +++ b/README_ja.md @@ -399,6 +399,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (Huawei Noah’s Ark Lab から) Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu から公開された研究論文: [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (Meta から) the NLLB team から公開された研究論文: [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) 1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (the University of Wisconsin - Madison から) Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh から公開された研究論文: [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) +1. **[OneFormer](https://huggingface.co/docs/transformers/main/model_doc/oneformer)** (SHI Labs から) Jitesh Jain, Jiachen Li, MangTik Chiu, Ali Hassani, Nikita Orlov, Humphrey Shi から公開された研究論文: [OneFormer: One Transformer to Rule Universal Image Segmentation](https://arxiv.org/abs/2211.06220) 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (Meta AI から) Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al から公開された研究論文: [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (Google AI から) Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby から公開された研究論文: [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (Google から) Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu から公開された研究論文: [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) diff --git a/README_ko.md b/README_ko.md index c93eb28415e0..bef63246d7a8 100644 --- a/README_ko.md +++ b/README_ko.md @@ -314,6 +314,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (Huawei Noah’s Ark Lab 에서) Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu 의 [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) 논문과 함께 발표했습니다. 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (Meta 에서) the NLLB team 의 [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) 논문과 함께 발표했습니다. 1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (the University of Wisconsin - Madison 에서) Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh 의 [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) 논문과 함께 발표했습니다. +1. **[OneFormer](https://huggingface.co/docs/transformers/main/model_doc/oneformer)** (SHI Labs 에서) Jitesh Jain, Jiachen Li, MangTik Chiu, Ali Hassani, Nikita Orlov, Humphrey Shi 의 [OneFormer: One Transformer to Rule Universal Image Segmentation](https://arxiv.org/abs/2211.06220) 논문과 함께 발표했습니다. 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (Meta AI 에서) Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al 의 [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) 논문과 함께 발표했습니다. 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (Google AI 에서) Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby 의 [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) 논문과 함께 발표했습니다. 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (Google 에서) Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu 의 [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) 논문과 함께 발표했습니다. diff --git a/README_zh-hans.md b/README_zh-hans.md index 18cdbc015599..4f3e82040dd3 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -338,6 +338,7 @@ conda install -c huggingface transformers 1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (来自华为诺亚方舟实验室) 伴随论文 [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) 由 Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu 发布。 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (来自 Meta) 伴随论文 [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) 由 the NLLB team 发布。 1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (来自 the University of Wisconsin - Madison) 伴随论文 [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) 由 Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh 发布。 +1. **[OneFormer](https://huggingface.co/docs/transformers/main/model_doc/oneformer)** (来自 SHI Labs) 伴随论文 [OneFormer: One Transformer to Rule Universal Image Segmentation](https://arxiv.org/abs/2211.06220) 由 Jitesh Jain, Jiachen Li, MangTik Chiu, Ali Hassani, Nikita Orlov, Humphrey Shi 发布。 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (来自 Meta AI) 伴随论文 [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) 由 Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al 发布。 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (来自 Google AI) 伴随论文 [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) 由 Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby 发布。 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (来自 Google) 伴随论文 [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) 由 Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index a4de0541fd7c..cde009a859fe 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -350,6 +350,7 @@ conda install -c huggingface transformers 1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu. 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team. 1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh. +1. **[OneFormer](https://huggingface.co/docs/transformers/main/model_doc/oneformer)** (from SHI Labs) released with the paper [OneFormer: One Transformer to Rule Universal Image Segmentation](https://arxiv.org/abs/2211.06220) by Jitesh Jain, Jiachen Li, MangTik Chiu, Ali Hassani, Nikita Orlov, Humphrey Shi. 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. diff --git a/docs/source/de/index.mdx b/docs/source/de/index.mdx index 031df91237f1..7efc14ad1a03 100644 --- a/docs/source/de/index.mdx +++ b/docs/source/de/index.mdx @@ -130,6 +130,7 @@ Die Bibliothek enthält derzeit JAX-, PyTorch- und TensorFlow-Implementierungen, 1. **[Nezha](model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu. 1. **[NLLB](model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team. 1. **[Nyströmformer](model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh. +1. **[OneFormer](model_doc/oneformer)** (from SHI Labs) released with the paper [OneFormer: One Transformer to Rule Universal Image Segmentation](https://arxiv.org/abs/2211.06220) by Jitesh Jain, Jiachen Li, MangTik Chiu, Ali Hassani, Nikita Orlov, Humphrey Shi. 1. **[OPT](master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[OWL-ViT](model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[Pegasus](model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 51dcd44b3c21..245eef17cdc4 100755 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -538,6 +538,8 @@ title: LayoutXLM - local: model_doc/lxmert title: LXMERT + - local: model_doc/oneformer + title: OneFormer - local: model_doc/owlvit title: OWL-ViT - local: model_doc/perceiver diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 335d26ebbbb7..016bc0b34e78 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -151,6 +151,7 @@ The documentation is organized into five sections: 1. **[Nezha](model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu. 1. **[NLLB](model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team. 1. **[Nyströmformer](model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh. +1. **[OneFormer](model_doc/oneformer)** (from SHI Labs) released with the paper [OneFormer: One Transformer to Rule Universal Image Segmentation](https://arxiv.org/abs/2211.06220) by Jitesh Jain, Jiachen Li, MangTik Chiu, Ali Hassani, Nikita Orlov, Humphrey Shi. 1. **[OPT](master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[OWL-ViT](model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[Pegasus](model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. @@ -322,6 +323,7 @@ Flax), PyTorch, and/or TensorFlow. | NAT | ❌ | ❌ | ✅ | ❌ | ❌ | | Nezha | ❌ | ❌ | ✅ | ❌ | ❌ | | Nyströmformer | ❌ | ❌ | ✅ | ❌ | ❌ | +| OneFormer | ❌ | ❌ | ✅ | ❌ | ❌ | | OpenAI GPT | ✅ | ✅ | ✅ | ✅ | ❌ | | OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ✅ | | OPT | ❌ | ❌ | ✅ | ✅ | ✅ | diff --git a/docs/source/en/model_doc/oneformer.mdx b/docs/source/en/model_doc/oneformer.mdx new file mode 100644 index 000000000000..85b40ea80de6 --- /dev/null +++ b/docs/source/en/model_doc/oneformer.mdx @@ -0,0 +1,72 @@ + + +# OneFormer + +## Overview + +The OneFormer model was proposed in [OneFormer: One Transformer to Rule Universal Image Segmentation](https://arxiv.org/abs/2211.06220) by Jitesh Jain, Jiachen Li, MangTik Chiu, Ali Hassani, Nikita Orlov, Humphrey Shi. OneFormer is a universal image segmentation framework that can be trained on a single panoptic dataset to perform semantic, instance, and panoptic segmentation tasks. OneFormer uses a task token to condition the model on the task in focus, making the architecture task-guided for training, and task-dynamic for inference. + + + +The abstract from the paper is the following: + +*Universal Image Segmentation is not a new concept. Past attempts to unify image segmentation in the last decades include scene parsing, panoptic segmentation, and, more recently, new panoptic architectures. However, such panoptic architectures do not truly unify image segmentation because they need to be trained individually on the semantic, instance, or panoptic segmentation to achieve the best performance. Ideally, a truly universal framework should be trained only once and achieve SOTA performance across all three image segmentation tasks. To that end, we propose OneFormer, a universal image segmentation framework that unifies segmentation with a multi-task train-once design. We first propose a task-conditioned joint training strategy that enables training on ground truths of each domain (semantic, instance, and panoptic segmentation) within a single multi-task training process. Secondly, we introduce a task token to condition our model on the task at hand, making our model task-dynamic to support multi-task training and inference. Thirdly, we propose using a query-text contrastive loss during training to establish better inter-task and inter-class distinctions. Notably, our single OneFormer model outperforms specialized Mask2Former models across all three segmentation tasks on ADE20k, CityScapes, and COCO, despite the latter being trained on each of the three tasks individually with three times the resources. With new ConvNeXt and DiNAT backbones, we observe even more performance improvement. We believe OneFormer is a significant step towards making image segmentation more universal and accessible.* + +Tips: +- OneFormer requires two inputs during inference: *image* and *task token*. +- During training, OneFormer only uses panoptic annotations. +- If you want to train the model in a distributed environment across multiple nodes, then one should update the + `get_num_masks` function inside in the `OneFormerLoss` class of `modeling_oneformer.py`. When training on multiple nodes, this should be + set to the average number of target masks across all nodes, as can be seen in the original implementation [here](https://github.com/SHI-Labs/OneFormer/blob/33ebb56ed34f970a30ae103e786c0cb64c653d9a/oneformer/modeling/criterion.py#L287). +- One can use [`OneFormerProcessor`] to prepare input images and task inputs for the model and optional targets for the model. [`OneformerProcessor`] wraps [`OneFormerImageProcessor`] and [`CLIPTokenizer`] into a single instance to both prepare the images and encode the task inputs. +- To get the final segmentation, depending on the task, you can call [`~OneFormerProcessor.post_process_semantic_segmentation`] or [`~OneFormerImageProcessor.post_process_instance_segmentation`] or [`~OneFormerImageProcessor.post_process_panoptic_segmentation`]. All three tasks can be solved using [`OneFormerForUniversalSegmentation`] output, panoptic segmentation accepts an optional `label_ids_to_fuse` argument to fuse instances of the target object/s (e.g. sky) together. + +The figure below illustrates the architecture of OneFormer. Taken from the [original paper](https://arxiv.org/abs/2211.06220). + + + +This model was contributed by [Jitesh Jain](https://huggingface.co/praeclarumjj3). The original code can be found [here](https://github.com/SHI-Labs/OneFormer). + +## OneFormer specific outputs + +[[autodoc]] models.oneformer.modeling_oneformer.OneFormerModelOutput + +[[autodoc]] models.oneformer.modeling_oneformer.OneFormerForUniversalSegmentationOutput + +## OneFormerConfig + +[[autodoc]] OneFormerConfig + +## OneFormerImageProcessor + +[[autodoc]] OneFormerImageProcessor + - preprocess + - encode_inputs + - post_process_semantic_segmentation + - post_process_instance_segmentation + - post_process_panoptic_segmentation + +## OneFormerProcessor + +[[autodoc]] OneFormerProcessor + +## OneFormerModel + +[[autodoc]] OneFormerModel + - forward + +## OneFormerForUniversalSegmentation + +[[autodoc]] OneFormerForUniversalSegmentation + - forward + \ No newline at end of file diff --git a/docs/source/es/index.mdx b/docs/source/es/index.mdx index 30ba547832d1..aca5d9f705d2 100644 --- a/docs/source/es/index.mdx +++ b/docs/source/es/index.mdx @@ -109,6 +109,7 @@ La biblioteca actualmente contiene implementaciones de JAX, PyTorch y TensorFlow 1. **[MPNet](model_doc/mpnet)** (de Microsoft Research) publicado con el paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) por Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. 1. **[MT5](model_doc/mt5)** (de Google AI) publicado con el paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) por Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. 1. **[Nyströmformer](model_doc/nystromformer)** (de la Universidad de Wisconsin - Madison) publicado con el paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) por Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh. +1. **[OneFormer](model_doc/oneformer)** (de la SHI Labs) publicado con el paper [OneFormer: One Transformer to Rule Universal Image Segmentation](https://arxiv.org/abs/2211.06220) por Jitesh Jain, Jiachen Li, MangTik Chiu, Ali Hassani, Nikita Orlov, Humphrey Shi. 1. **[Pegasus](model_doc/pegasus)** (de Google) publicado con el paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) por Jingqing Zhang, Yao Zhao, Mohammad Saleh y Peter J. Liu. 1. **[Perceiver IO](model_doc/perceiver)** (de Deepmind) publicado con el paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) por Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira. 1. **[PhoBERT](model_doc/phobert)** (de VinAI Research) publicado con el paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) por Dat Quoc Nguyen y Anh Tuan Nguyen. diff --git a/docs/source/it/index.mdx b/docs/source/it/index.mdx index aa1d7e25d6d9..38ab8f8aa64c 100644 --- a/docs/source/it/index.mdx +++ b/docs/source/it/index.mdx @@ -118,6 +118,7 @@ La libreria attualmente contiene implementazioni in JAX, PyTorch e TensorFlow, p 1. **[MPNet](model_doc/mpnet)** (da Microsoft Research) rilasciato con il paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) da Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. 1. **[MT5](model_doc/mt5)** (da Google AI) rilasciato con il paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) da Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. 1. **[Nyströmformer](model_doc/nystromformer)** (dalla Università del Wisconsin - Madison) rilasciato con il paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) da Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh. +1. **[OneFormer](model_doc/oneformer)** (da SHI Labs) rilasciato con il paper [OneFormer: One Transformer to Rule Universal Image Segmentation](https://arxiv.org/abs/2211.06220) da Jitesh Jain, Jiachen Li, MangTik Chiu, Ali Hassani, Nikita Orlov, Humphrey Shi. 1. **[OPT](master/model_doc/opt)** (da Meta AI) rilasciato con il paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) da Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[Pegasus](model_doc/pegasus)** (da Google) rilasciato con il paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) da Jingqing Zhang, Yao Zhao, Mohammad Saleh e Peter J. Liu. 1. **[Perceiver IO](model_doc/perceiver)** (da Deepmind) rilasciato con il paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) da Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira. diff --git a/docs/source/ko/index.mdx b/docs/source/ko/index.mdx index 4ae17ab7becd..3947982e4d5a 100644 --- a/docs/source/ko/index.mdx +++ b/docs/source/ko/index.mdx @@ -139,6 +139,7 @@ specific language governing permissions and limitations under the License. 1. **[Nezha](model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu. 1. **[NLLB](model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team. 1. **[Nyströmformer](model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh. +1. **[OneFormer](model_doc/oneformer)** (from SHI Labs) released with the paper [OneFormer: One Transformer to Rule Universal Image Segmentation](https://arxiv.org/abs/2211.06220) by Jitesh Jain, Jiachen Li, MangTik Chiu, Ali Hassani, Nikita Orlov, Humphrey Shi. 1. **[OPT](master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. 1. **[OWL-ViT](model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[Pegasus](model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. diff --git a/docs/source/pt/index.mdx b/docs/source/pt/index.mdx index 7f1d26b6c447..ff841d91f495 100644 --- a/docs/source/pt/index.mdx +++ b/docs/source/pt/index.mdx @@ -123,6 +123,7 @@ Atualmente a biblioteca contém implementações do PyTorch, TensorFlow e JAX, p 1. **[MPNet](model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. 1. **[MT5](model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. 1. **[Nyströmformer](model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh. +1. **[OneFormer](model_doc/oneformer)** (from SHI Labs) released with the paper [OneFormer: One Transformer to Rule Universal Image Segmentation](https://arxiv.org/abs/2211.06220) by Jitesh Jain, Jiachen Li, MangTik Chiu, Ali Hassani, Nikita Orlov, Humphrey Shi. 1. **[Pegasus](model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. 1. **[Perceiver IO](model_doc/perceiver)** (from Deepmind) released with the paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) by Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira. 1. **[PhoBERT](model_doc/phobert)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen. diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e3b675024627..3d7fd017a20c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -348,6 +348,7 @@ "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "NystromformerConfig", ], + "models.oneformer": ["ONEFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "OneFormerConfig", "OneFormerProcessor"], "models.openai": ["OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "OpenAIGPTConfig", "OpenAIGPTTokenizer"], "models.opt": ["OPTConfig"], "models.owlvit": [ @@ -799,6 +800,7 @@ _import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"]) _import_structure["models.mobilenet_v2"].extend(["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"]) _import_structure["models.mobilevit"].extend(["MobileViTFeatureExtractor", "MobileViTImageProcessor"]) + _import_structure["models.oneformer"].extend(["OneFormerImageProcessor"]) _import_structure["models.owlvit"].extend(["OwlViTFeatureExtractor", "OwlViTImageProcessor"]) _import_structure["models.perceiver"].extend(["PerceiverFeatureExtractor", "PerceiverImageProcessor"]) _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"]) @@ -1871,6 +1873,14 @@ "NystromformerPreTrainedModel", ] ) + _import_structure["models.oneformer"].extend( + [ + "ONEFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "OneFormerForUniversalSegmentation", + "OneFormerModel", + "OneFormerPreTrainedModel", + ] + ) _import_structure["models.openai"].extend( [ "OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -3729,6 +3739,7 @@ from .models.nat import NAT_PRETRAINED_CONFIG_ARCHIVE_MAP, NatConfig from .models.nezha import NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP, NezhaConfig from .models.nystromformer import NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, NystromformerConfig + from .models.oneformer import ONEFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, OneFormerConfig, OneFormerProcessor from .models.openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, OpenAIGPTTokenizer from .models.opt import OPTConfig from .models.owlvit import ( @@ -4122,6 +4133,7 @@ from .models.mobilenet_v1 import MobileNetV1FeatureExtractor, MobileNetV1ImageProcessor from .models.mobilenet_v2 import MobileNetV2FeatureExtractor, MobileNetV2ImageProcessor from .models.mobilevit import MobileViTFeatureExtractor, MobileViTImageProcessor + from .models.oneformer import OneFormerImageProcessor from .models.owlvit import OwlViTFeatureExtractor, OwlViTImageProcessor from .models.perceiver import PerceiverFeatureExtractor, PerceiverImageProcessor from .models.poolformer import PoolFormerFeatureExtractor, PoolFormerImageProcessor @@ -5000,6 +5012,12 @@ NystromformerModel, NystromformerPreTrainedModel, ) + from .models.oneformer import ( + ONEFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + OneFormerForUniversalSegmentation, + OneFormerModel, + OneFormerPreTrainedModel, + ) from .models.openai import ( OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST, OpenAIGPTDoubleHeadsModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index cf30880faa1c..f7bcb026e570 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -122,6 +122,7 @@ nezha, nllb, nystromformer, + oneformer, openai, opt, owlvit, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index cc3d48fe3be8..d64432b76cf1 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -120,6 +120,7 @@ ("nat", "NatConfig"), ("nezha", "NezhaConfig"), ("nystromformer", "NystromformerConfig"), + ("oneformer", "OneFormerConfig"), ("openai-gpt", "OpenAIGPTConfig"), ("opt", "OPTConfig"), ("owlvit", "OwlViTConfig"), @@ -276,6 +277,7 @@ ("nat", "NAT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("nezha", "NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("oneformer", "ONEFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("opt", "OPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("owlvit", "OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -447,6 +449,7 @@ ("nezha", "Nezha"), ("nllb", "NLLB"), ("nystromformer", "Nyströmformer"), + ("oneformer", "OneFormer"), ("openai-gpt", "OpenAI GPT"), ("opt", "OPT"), ("owlvit", "OWL-ViT"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 3b057fa2a8a7..a957abd0c083 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -69,6 +69,7 @@ ("mobilevit", "MobileViTImageProcessor"), ("mobilevit", "MobileViTImageProcessor"), ("nat", "ViTImageProcessor"), + ("oneformer", "OneFormerImageProcessor"), ("owlvit", "OwlViTImageProcessor"), ("perceiver", "PerceiverImageProcessor"), ("poolformer", "PoolFormerImageProcessor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 4465097dfeed..5a164624be63 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -120,6 +120,7 @@ ("nezha", "NezhaModel"), ("nllb", "M2M100Model"), ("nystromformer", "NystromformerModel"), + ("oneformer", "OneFormerModel"), ("openai-gpt", "OpenAIGPTModel"), ("opt", "OPTModel"), ("owlvit", "OwlViTModel"), @@ -457,6 +458,7 @@ ("detr", "DetrForSegmentation"), ("mask2former", "Mask2FormerForUniversalSegmentation"), ("maskformer", "MaskFormerForInstanceSegmentation"), + ("oneformer", "OneFormerForUniversalSegmentation"), ] ) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index f1ad8f221adf..a42b510ee756 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -53,6 +53,7 @@ ("layoutlmv3", "LayoutLMv3Processor"), ("layoutxlm", "LayoutXLMProcessor"), ("markuplm", "MarkupLMProcessor"), + ("oneformer", "OneFormerProcessor"), ("owlvit", "OwlViTProcessor"), ("sew", "Wav2Vec2Processor"), ("sew-d", "Wav2Vec2Processor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 0b21273ca96c..94da66961c0e 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -208,6 +208,7 @@ "AlbertTokenizerFast" if is_tokenizers_available() else None, ), ), + ("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)), ("opt", ("GPT2Tokenizer", None)), ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/oneformer/__init__.py b/src/transformers/models/oneformer/__init__.py new file mode 100644 index 000000000000..5530e7088559 --- /dev/null +++ b/src/transformers/models/oneformer/__init__.py @@ -0,0 +1,77 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available + + +_import_structure = { + "configuration_oneformer": ["ONEFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "OneFormerConfig"], + "processing_oneformer": ["OneFormerProcessor"], +} + +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["image_processing_oneformer"] = ["OneFormerImageProcessor"] + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_oneformer"] = [ + "ONEFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "OneFormerForUniversalSegmentation", + "OneFormerModel", + "OneFormerPreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_oneformer import ONEFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, OneFormerConfig + from .processing_oneformer import OneFormerProcessor + + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .image_processing_oneformer import OneFormerImageProcessor + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_oneformer import ( + ONEFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + OneFormerForUniversalSegmentation, + OneFormerModel, + OneFormerPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/oneformer/configuration_oneformer.py b/src/transformers/models/oneformer/configuration_oneformer.py new file mode 100644 index 000000000000..67bbe8044a55 --- /dev/null +++ b/src/transformers/models/oneformer/configuration_oneformer.py @@ -0,0 +1,266 @@ +# coding=utf-8 +# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OneFormer model configuration""" +import copy +from typing import Dict, Optional + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + +ONEFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "shi-labs/oneformer_ade20k_swin_tiny": ( + "https://huggingface.co/shi-labs/oneformer_ade20k_swin_tiny/blob/main/config.json" + ), + # See all OneFormer models at https://huggingface.co/models?filter=oneformer +} + + +class OneFormerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`OneFormerModel`]. It is used to instantiate a + OneFormer model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the OneFormer + [shi-labs/oneformer_ade20k_swin_tiny](https://huggingface.co/shi-labs/oneformer_ade20k_swin_tiny) architecture + trained on [ADE20k-150](https://huggingface.co/datasets/scene_parse_150). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`PretrainedConfig`, *optional*, defaults to `SwinConfig`) + The configuration of the backbone model. + ignore_value (`int`, *optional*, defaults to 255) + Values to be ignored in GT label while calculating loss. + num_queries (`int`, *optional*, defaults to 150) + Number of object queries. + no_object_weight (`float`, *optional*, defaults to 0.1) + Weight for no-object class predictions. + class_weight (`float`, *optional*, defaults to 2.0) + Weight for Classification CE loss. + mask_weight (`float`, *optional*, defaults to 5.0) + Weight for binary CE loss. + dice_weight (`float`, *optional*, defaults to 5.0) + Weight for dice loss. + contrastive_weight (`float`, *optional*, defaults to 0.5) + Weight for contrastive loss. + contrastive_temperature (`float`, *optional*, defaults to 0.07) + Initial value for scaling the contrastive logits. + train_num_points (`int`, *optional*, defaults to 12544) + Number of points to sample while calculating losses on mask predictions. + oversample_ratio (`float`, *optional*, defaults to 3.0) + Ratio to decide how many points to oversample. + importance_sample_ratio (`float`, *optional*, defaults to 0.75) + Ratio of points that are sampled via importance sampling. + init_std (`float`, *optional*, defaults to 0.02) + Standard deviation for normal intialization. + init_xavier_std (`float`, *optional*, defaults to 0.02) + Standard deviation for xavier uniform initialization. + layer_norm_eps (`float`, *optional*, defaults to 1e-05) + Epsilon for layer normalization. + is_training (`bool`, *optional*, defaults to False) + Whether to run in training or inference mode. + use_auxiliary_loss (`bool`, *optional*, defaults to True) + Whether to calculate loss using intermediate predictions from transformer decoder. + output_auxiliary_logits (`bool`, *optional*, defaults to True) + Whether to return intermediate predictions from transformer decoder. + strides (`list`, *optional*, defaults to [4, 8, 16, 32]) + List containing the strides for feature maps in the encoder. + task_seq_len (`int`, *optional*, defaults to 77) + Sequence length for tokenizing text list input. + max_seq_len (`int`, *optional*, defaults to 77) + Sequence length for tokenizing task input. + text_encoder_width (`int`, *optional*, defaults to 256) + Hidden size for text encoder. + text_encoder_context_length (`int`, *optional*, defaults to 77): + Input sequence length for text encoder. + text_encoder_num_layers (`int`, *optional*, defaults to 6) + Number of layers for transformer in text encoder. + text_encoder_vocab_size (`int`, *optional*, defaults to 49408) + Vocabulary size for tokenizer. + text_encoder_proj_layers (`int`, *optional*, defaults to 2) + Number of layers in MLP for project text queries. + text_encoder_n_ctx (`int`, *optional*, defaults to 16) + Number of learnable text context queries. + conv_dim (`int`, *optional*, defaults to 256) + Feature map dimension to map outputs from the backbone. + mask_dim (`int`, *optional*, defaults to 256) + Dimension for feature maps in pixel decoder. + hidden_dim (`int`, *optional*, defaults to 256) + Dimension for hidden states in transformer decoder. + encoder_feedforward_dim (`int`, *optional*, defaults to 1024) + Dimension for FFN layer in pixel decoder. + norm (`str`, *optional*, defaults to `GN`) + Type of normalization. + encoder_layers (`int`, *optional*, defaults to 6) + Number of layers in pixel decoder. + decoder_layers (`int`, *optional*, defaults to 10) + Number of layers in transformer decoder. + use_task_norm (`bool`, *optional*, defaults to `True`) + Whether to normalize the task token. + num_attention_heads (`int`, *optional*, defaults to 8) + Number of attention heads in transformer layers in the pixel and transformer decoders. + dropout (`float`, *optional*, defaults to 0.1) + Dropout probability for pixel and transformer decoders. + dim_feedforward (`int`, *optional*, defaults to 2048) + Dimension for FFN layer in transformer decoder. + pre_norm (`bool`, *optional*, defaults to `False`) + Whether to normalize hidden states before attention layers in transformer decoder. + enforce_input_proj (`bool`, *optional*, defaults to `False`) + Whether to project hidden states in transformer decoder. + query_dec_layers (`int`, *optional*, defaults to 2) + Number of layers in query transformer. + common_stride (`int`, *optional*, defaults to 4) + Common stride used for features in pixel decoder. + + Examples: + ```python + >>> from transformers import OneFormerConfig, OneFormerModel + + >>> # Initializing a OneFormer shi-labs/oneformer_ade20k_swin_tiny configuration + >>> configuration = OneFormerConfig() + >>> # Initializing a model (with random weights) from the shi-labs/oneformer_ade20k_swin_tiny style configuration + >>> model = OneFormerModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + model_type = "oneformer" + attribute_map = {"hidden_size": "hidden_dim"} + + def __init__( + self, + backbone_config: Optional[Dict] = None, + ignore_value: int = 255, + num_queries: int = 150, + no_object_weight: int = 0.1, + class_weight: float = 2.0, + mask_weight: float = 5.0, + dice_weight: float = 5.0, + contrastive_weight: float = 0.5, + contrastive_temperature: float = 0.07, + train_num_points: int = 12544, + oversample_ratio: float = 3.0, + importance_sample_ratio: float = 0.75, + init_std: float = 0.02, + init_xavier_std: float = 1.0, + layer_norm_eps: float = 1e-05, + is_training: bool = False, + use_auxiliary_loss: bool = True, + output_auxiliary_logits: bool = True, + strides: Optional[list] = [4, 8, 16, 32], + task_seq_len: int = 77, + max_seq_len: int = 77, + text_encoder_width: int = 256, + text_encoder_context_length: int = 77, + text_encoder_num_layers: int = 6, + text_encoder_vocab_size: int = 49408, + text_encoder_proj_layers: int = 2, + text_encoder_n_ctx: int = 16, + conv_dim: int = 256, + mask_dim: int = 256, + hidden_dim: int = 256, + encoder_feedforward_dim: int = 1024, + norm: str = "GN", + encoder_layers: int = 6, + decoder_layers: int = 10, + use_task_norm: bool = True, + num_attention_heads: int = 8, + dropout: float = 0.1, + dim_feedforward: int = 2048, + pre_norm: bool = False, + enforce_input_proj: bool = False, + query_dec_layers: int = 2, + common_stride: int = 4, + **kwargs, + ): + if backbone_config is None: + logger.info("`backbone_config` is unset. Initializing the config with the default `Swin` backbone.") + backbone_config = CONFIG_MAPPING["swin"]( + image_size=224, + in_channels=3, + patch_size=4, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + drop_path_rate=0.3, + use_absolute_embeddings=False, + out_features=["stage1", "stage2", "stage3", "stage4"], + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + self.backbone_config = backbone_config + + self.ignore_value = ignore_value + self.num_queries = num_queries + self.no_object_weight = no_object_weight + self.class_weight = class_weight + self.mask_weight = mask_weight + self.dice_weight = dice_weight + self.contrastive_weight = contrastive_weight + self.contrastive_temperature = contrastive_temperature + self.train_num_points = train_num_points + self.oversample_ratio = oversample_ratio + self.importance_sample_ratio = importance_sample_ratio + self.init_std = init_std + self.init_xavier_std = init_xavier_std + self.layer_norm_eps = layer_norm_eps + self.is_training = is_training + self.use_auxiliary_loss = use_auxiliary_loss + self.output_auxiliary_logits = output_auxiliary_logits + self.strides = strides + self.task_seq_len = task_seq_len + self.max_seq_len = max_seq_len + self.text_encoder_width = text_encoder_width + self.text_encoder_context_length = text_encoder_context_length + self.text_encoder_num_layers = text_encoder_num_layers + self.text_encoder_vocab_size = text_encoder_vocab_size + self.text_encoder_proj_layers = text_encoder_proj_layers + self.text_encoder_n_ctx = text_encoder_n_ctx + self.conv_dim = conv_dim + self.mask_dim = mask_dim + self.hidden_dim = hidden_dim + self.encoder_feedforward_dim = encoder_feedforward_dim + self.norm = norm + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.use_task_norm = use_task_norm + self.num_attention_heads = num_attention_heads + self.dropout = dropout + self.dim_feedforward = dim_feedforward + self.pre_norm = pre_norm + self.enforce_input_proj = enforce_input_proj + self.query_dec_layers = query_dec_layers + self.common_stride = common_stride + self.num_hidden_layers = decoder_layers + + super().__init__(**kwargs) + + def to_dict(self) -> Dict[str, any]: + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["backbone_config"] = self.backbone_config.to_dict() + output["model_type"] = self.__class__.model_type + return output diff --git a/src/transformers/models/oneformer/convert_to_hf_oneformer.py b/src/transformers/models/oneformer/convert_to_hf_oneformer.py new file mode 100644 index 000000000000..074cc659b268 --- /dev/null +++ b/src/transformers/models/oneformer/convert_to_hf_oneformer.py @@ -0,0 +1,1192 @@ +# coding=utf-8 +# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert OneFormer checkpoints from the original repository. URL: https://github.com/SHI-Labs/OneFormer""" + +import os +import sys +from argparse import ArgumentParser +from dataclasses import dataclass +from pathlib import Path +from pprint import pformat +from typing import Any, Dict, Iterator, List, Set, Tuple + +import torch +import torchvision.transforms as T +from PIL import Image +from torch import Tensor, nn + +import requests + + +try: + from detectron2.checkpoint import DetectionCheckpointer + from detectron2.config import get_cfg + from detectron2.data import MetadataCatalog + from detectron2.projects.deeplab import add_deeplab_config +except ImportError: + pass +from transformers import CLIPTokenizer, DinatConfig, SwinConfig +from transformers.models.oneformer.image_processing_oneformer import OneFormerImageProcessor +from transformers.models.oneformer.modeling_oneformer import ( + OneFormerConfig, + OneFormerForUniversalSegmentation, + OneFormerForUniversalSegmentationOutput, + OneFormerModel, + OneFormerModelOutput, +) +from transformers.models.oneformer.processing_oneformer import OneFormerProcessor +from transformers.utils import logging + + +StateDict = Dict[str, Tensor] + +logging.set_verbosity_info() +logger = logging.get_logger() + +torch.manual_seed(0) + + +class TrackedStateDict: + def __init__(self, to_track: Dict): + """This class "tracks" a python dictionary by keeping track of which item is accessed. + + Args: + to_track (Dict): The dictionary we wish to track + """ + self.to_track = to_track + self._seen: Set[str] = set() + + def __getitem__(self, key: str) -> Any: + return self.to_track[key] + + def __setitem__(self, key: str, item: Any): + self._seen.add(key) + self.to_track[key] = item + + def diff(self) -> List[str]: + """This method returns a set difference between the keys in the tracked state dict and the one we have access so far. + This is an effective method to check if we have update all the keys + + Returns: + List[str]: List of keys not yet updated + """ + return set(list(self.to_track.keys())) - self._seen + + def copy(self) -> Dict: + # proxy the call to the internal dictionary + return self.to_track.copy() + + +# Image to verify the result +def prepare_img(): + url = "https://praeclarumjj3.github.io/files/coco.jpeg" + img_data = requests.get(url, stream=True).raw + im = Image.open(img_data) + return im + + +@dataclass +class Args: + """Fake command line arguments needed by oneformer/detectron2 implementation""" + + config_file: str + + +def setup_cfg(args: Args): + # load config from file and command-line arguments + cfg = get_cfg() + add_deeplab_config(cfg) + add_common_config(cfg) + add_oneformer_config(cfg) + add_swin_config(cfg) + add_dinat_config(cfg) + cfg.merge_from_file(args.config_file) + cfg.freeze() + return cfg + + +class OriginalOneFormerConfigToOursConverter: + def __call__(self, original_config: object, is_swin: bool) -> OneFormerConfig: + model = original_config.MODEL + + dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST_PANOPTIC[0]) + id2label = {idx: label for idx, label in enumerate(dataset_catalog.stuff_classes)} + label2id = {label: idx for idx, label in id2label.items()} + + if is_swin: + if model.SWIN.EMBED_DIM == 96: + backbone_config = SwinConfig.from_pretrained( + "microsoft/swin-tiny-patch4-window7-224", + drop_path_rate=model.SWIN.DROP_PATH_RATE, + out_features=["stage1", "stage2", "stage3", "stage4"], + ) + elif model.SWIN.EMBED_DIM == 192: + backbone_config = SwinConfig.from_pretrained( + "microsoft/swin-large-patch4-window12-384", + drop_path_rate=model.SWIN.DROP_PATH_RATE, + out_features=["stage1", "stage2", "stage3", "stage4"], + ) + else: + raise ValueError(f"embed dim {model.SWIN.EMBED_DIM} not supported for Swin!") + else: + backbone_config = DinatConfig.from_pretrained( + "shi-labs/dinat-large-11x11-in22k-in1k-384", + dilations=model.DiNAT.DILATIONS, + kernel_size=model.DiNAT.KERNEL_SIZE, + out_features=["stage1", "stage2", "stage3", "stage4"], + ) + + config: OneFormerConfig = OneFormerConfig( + backbone_config=backbone_config, + output_attentions=True, + output_hidden_states=True, + return_dict=True, + ignore_value=model.SEM_SEG_HEAD.IGNORE_VALUE, + num_classes=model.SEM_SEG_HEAD.NUM_CLASSES, + num_queries=model.ONE_FORMER.NUM_OBJECT_QUERIES, + no_object_weight=model.ONE_FORMER.NO_OBJECT_WEIGHT, + class_weight=model.ONE_FORMER.CLASS_WEIGHT, + mask_weight=model.ONE_FORMER.MASK_WEIGHT, + dice_weight=model.ONE_FORMER.DICE_WEIGHT, + contrastive_weight=model.ONE_FORMER.CONTRASTIVE_WEIGHT, + contrastive_temperature=model.ONE_FORMER.CONTRASTIVE_TEMPERATURE, + train_num_points=model.ONE_FORMER.TRAIN_NUM_POINTS, + oversample_ratio=model.ONE_FORMER.OVERSAMPLE_RATIO, + importance_sample_ratio=model.ONE_FORMER.IMPORTANCE_SAMPLE_RATIO, + init_std=0.02, + init_xavier_std=1.0, + layer_norm_eps=1e-05, + is_training=False, + use_auxiliary_loss=model.ONE_FORMER.DEEP_SUPERVISION, + output_auxiliary_logits=True, + strides=[4, 8, 16, 32], + task_seq_len=original_config.INPUT.TASK_SEQ_LEN, + max_seq_len=original_config.INPUT.MAX_SEQ_LEN, + text_encoder_width=model.TEXT_ENCODER.WIDTH, + text_encoder_context_length=model.TEXT_ENCODER.CONTEXT_LENGTH, + text_encoder_num_layers=model.TEXT_ENCODER.NUM_LAYERS, + text_encoder_vocab_size=model.TEXT_ENCODER.VOCAB_SIZE, + text_encoder_proj_layers=model.TEXT_ENCODER.PROJ_NUM_LAYERS, + text_encoder_n_ctx=model.TEXT_ENCODER.N_CTX, + conv_dim=model.SEM_SEG_HEAD.CONVS_DIM, + mask_dim=model.SEM_SEG_HEAD.MASK_DIM, + hidden_dim=model.ONE_FORMER.HIDDEN_DIM, + norm=model.SEM_SEG_HEAD.NORM, + encoder_layers=model.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS, + encoder_feedforward_dim=1024, + decoder_layers=model.ONE_FORMER.DEC_LAYERS, + use_task_norm=model.ONE_FORMER.USE_TASK_NORM, + num_attention_heads=model.ONE_FORMER.NHEADS, + dropout=model.ONE_FORMER.DROPOUT, + dim_feedforward=model.ONE_FORMER.DIM_FEEDFORWARD, + pre_norm=model.ONE_FORMER.PRE_NORM, + enforce_input_proj=model.ONE_FORMER.ENFORCE_INPUT_PROJ, + query_dec_layers=model.ONE_FORMER.CLASS_DEC_LAYERS, + common_stride=model.SEM_SEG_HEAD.COMMON_STRIDE, + id2label=id2label, + label2id=label2id, + ) + + return config + + +class OriginalOneFormerConfigToProcessorConverter: + def __call__(self, original_config: object, model_repo: str) -> OneFormerProcessor: + model = original_config.MODEL + model_input = original_config.INPUT + dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST_PANOPTIC[0]) + + if "ade20k" in model_repo: + class_info_file = "ade20k_panoptic.json" + elif "coco" in model_repo: + class_info_file = "coco_panoptic.json" + elif "cityscapes" in model_repo: + class_info_file = "cityscapes_panoptic.json" + else: + raise ValueError("Invalid Dataset!") + + image_processor = OneFormerImageProcessor( + image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(), + image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(), + size=model_input.MIN_SIZE_TEST, + max_size=model_input.MAX_SIZE_TEST, + num_labels=model.SEM_SEG_HEAD.NUM_CLASSES, + ignore_index=dataset_catalog.ignore_label, + class_info_file=class_info_file, + ) + + tokenizer = CLIPTokenizer.from_pretrained(model_repo) + + return OneFormerProcessor( + image_processor=image_processor, + tokenizer=tokenizer, + task_seq_length=original_config.INPUT.TASK_SEQ_LEN, + max_seq_length=original_config.INPUT.MAX_SEQ_LEN, + ) + + +class OriginalOneFormerCheckpointToOursConverter: + def __init__(self, original_model: nn.Module, config: OneFormerConfig): + self.original_model = original_model + self.config = config + + def pop_all(self, renamed_keys: List[Tuple[str, str]], dst_state_dict: StateDict, src_state_dict: StateDict): + for src_key, dst_key in renamed_keys: + dst_state_dict[dst_key] = src_state_dict.pop(src_key) + + # Swin Backbone + def replace_swin_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: OneFormerConfig): + dst_prefix: str = "pixel_level_module.encoder" + src_prefix: str = "backbone" + + renamed_keys = [ + ( + f"{src_prefix}.patch_embed.proj.weight", + f"{dst_prefix}.embeddings.patch_embeddings.projection.weight", + ), + (f"{src_prefix}.patch_embed.proj.bias", f"{dst_prefix}.embeddings.patch_embeddings.projection.bias"), + (f"{src_prefix}.patch_embed.norm.weight", f"{dst_prefix}.embeddings.norm.weight"), + (f"{src_prefix}.patch_embed.norm.bias", f"{dst_prefix}.embeddings.norm.bias"), + ] + num_layers = len(config.backbone_config.depths) + for layer_idx in range(num_layers): + for block_idx in range(config.backbone_config.depths[layer_idx]): + renamed_keys.extend( + [ # src, dst + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.bias", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_bias_table", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_bias_table", + ), + ] + ) + # now we need to handle the attentions + # read in weights + bias of input projection layer of cross-attention + + src_att_weight = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"] + src_att_bias = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"] + + size = src_att_weight.shape[0] + offset = size // 3 + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.weight" + ] = src_att_weight[:offset, :] + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.bias" + ] = src_att_bias[:offset] + + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.weight" + ] = src_att_weight[offset : offset * 2, :] + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.bias" + ] = src_att_bias[offset : offset * 2] + + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.weight" + ] = src_att_weight[-offset:, :] + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.bias" + ] = src_att_bias[-offset:] + + # let's pop them + src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight") + src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias") + # proj + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.bias", + ), + ] + ) + + # second norm + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.bias", + ), + ] + ) + + # mlp + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.bias", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.bias", + ), + ] + ) + + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_index", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_index", + ) + ] + ) + + if layer_idx < num_layers - 1: + # patch merging + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.downsample.reduction.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.reduction.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.downsample.norm.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.norm.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.downsample.norm.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.norm.bias", + ), + ] + ) + + # hidden states norms + renamed_keys.extend( + [ + ( + f"{src_prefix}.norm{layer_idx}.weight", + f"{dst_prefix}.hidden_states_norms.stage{layer_idx+1}.weight", + ), + ( + f"{src_prefix}.norm{layer_idx}.bias", + f"{dst_prefix}.hidden_states_norms.stage{layer_idx+1}.bias", + ), + ] + ) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + # Dinat Backbone + def replace_dinat_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: OneFormerConfig): + dst_prefix: str = "pixel_level_module.encoder" + src_prefix: str = "backbone" + + def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str): + return [ + (f"{src_prefix}.weight", f"{dst_prefix}.weight"), + (f"{src_prefix}.bias", f"{dst_prefix}.bias"), + ] + + renamed_keys = rename_keys_for_weight_bias(f"{src_prefix}.patch_embed.norm", f"{dst_prefix}.embeddings.norm") + + for i in range(2): + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.patch_embed.proj.{i}", + f"{dst_prefix}.embeddings.patch_embeddings.projection.{i}", + ) + ) + + num_layers = len(config.backbone_config.depths) + for layer_idx in range(num_layers): + for block_idx in range(config.backbone_config.depths[layer_idx]): + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.norm1", + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.layernorm_before", + ) + ) + + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.norm2", + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.layernorm_after", + ) + ) + + renamed_keys.extend( + [ # src, dst + ( + f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.rpb", + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.rpb", + ), + ] + ) + # now we need to handle the attentions + # read in weights + bias of input projection layer of cross-attention + + src_att_weight = src_state_dict[f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"] + src_att_bias = src_state_dict[f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"] + + size = src_att_weight.shape[0] + offset = size // 3 + dst_state_dict[ + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.query.weight" + ] = src_att_weight[:offset, :] + dst_state_dict[ + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.query.bias" + ] = src_att_bias[:offset] + + dst_state_dict[ + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.key.weight" + ] = src_att_weight[offset : offset * 2, :] + dst_state_dict[ + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.key.bias" + ] = src_att_bias[offset : offset * 2] + + dst_state_dict[ + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.value.weight" + ] = src_att_weight[-offset:, :] + dst_state_dict[ + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.value.bias" + ] = src_att_bias[-offset:] + + # let's pop them + src_state_dict.pop(f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.weight") + src_state_dict.pop(f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.bias") + # proj + + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.proj", + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.output.dense", + ) + ) + + # mlp + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.mlp.fc1", + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.intermediate.dense", + ) + ) + + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.mlp.fc2", + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.output.dense", + ) + ) + + if layer_idx < num_layers - 1: + # patch merging + renamed_keys.extend( + [ + ( + f"{src_prefix}.levels.{layer_idx}.downsample.reduction.weight", + f"{dst_prefix}.encoder.levels.{layer_idx}.downsample.reduction.weight", + ), + ( + f"{src_prefix}.levels.{layer_idx}.downsample.norm.weight", + f"{dst_prefix}.encoder.levels.{layer_idx}.downsample.norm.weight", + ), + ( + f"{src_prefix}.levels.{layer_idx}.downsample.norm.bias", + f"{dst_prefix}.encoder.levels.{layer_idx}.downsample.norm.bias", + ), + ] + ) + + # hidden states norms + renamed_keys.extend( + [ + ( + f"{src_prefix}.norm{layer_idx}.weight", + f"{dst_prefix}.hidden_states_norms.stage{layer_idx+1}.weight", + ), + ( + f"{src_prefix}.norm{layer_idx}.bias", + f"{dst_prefix}.hidden_states_norms.stage{layer_idx+1}.bias", + ), + ] + ) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + # Backbone + Pixel Decoder + def replace_pixel_module(self, dst_state_dict: StateDict, src_state_dict: StateDict, is_swin: bool): + dst_prefix: str = "pixel_level_module.decoder" + src_prefix: str = "sem_seg_head.pixel_decoder" + + if is_swin: + self.replace_swin_backbone(dst_state_dict, src_state_dict, self.config) + else: + self.replace_dinat_backbone(dst_state_dict, src_state_dict, self.config) + + def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str): + return [ + (f"{src_prefix}.weight", f"{dst_prefix}.weight"), + (f"{src_prefix}.bias", f"{dst_prefix}.bias"), + ] + + def rename_keys_for_self_attn(src_prefix: str, dst_prefix: str): + self_attn_keys = [] + self_attn_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.attention_weights", f"{dst_prefix}.attention_weights") + ) + self_attn_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.output_proj", f"{dst_prefix}.output_proj") + ) + self_attn_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.sampling_offsets", f"{dst_prefix}.sampling_offsets") + ) + self_attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.value_proj", f"{dst_prefix}.value_proj")) + + return self_attn_keys + + def rename_keys_for_encoder_layer(src_prefix: str, dst_prefix: str): + encoder_keys = [] + encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.fc1")) + encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.fc2")) + encoder_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.norm1", f"{dst_prefix}.self_attn_layer_norm") + ) + encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm2", f"{dst_prefix}.final_layer_norm")) + encoder_keys.extend(rename_keys_for_self_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn")) + + return encoder_keys + + # convolution layer for final features + renamed_keys = [ + (f"{src_prefix}.adapter_1.weight", f"{dst_prefix}.adapter_1.0.weight"), + (f"{src_prefix}.adapter_1.norm.weight", f"{dst_prefix}.adapter_1.1.weight"), + (f"{src_prefix}.adapter_1.norm.bias", f"{dst_prefix}.adapter_1.1.bias"), + ] + + renamed_keys.extend( + [ + (f"{src_prefix}.layer_1.weight", f"{dst_prefix}.layer_1.0.weight"), + (f"{src_prefix}.layer_1.norm.weight", f"{dst_prefix}.layer_1.1.weight"), + (f"{src_prefix}.layer_1.norm.bias", f"{dst_prefix}.layer_1.1.bias"), + ] + ) + + # proj layers + for i in range(3): + for j in range(2): + renamed_keys.extend( + [ + (f"{src_prefix}.input_proj.{i}.{j}.weight", f"{dst_prefix}.input_projections.{i}.{j}.weight"), + (f"{src_prefix}.input_proj.{i}.{j}.bias", f"{dst_prefix}.input_projections.{i}.{j}.bias"), + ] + ) + + renamed_keys.extend([(f"{src_prefix}.transformer.level_embed", f"{dst_prefix}.level_embed")]) + + # layers + for layer_idx in range(self.config.encoder_layers): + renamed_keys.extend( + rename_keys_for_encoder_layer( + f"{src_prefix}.transformer.encoder.layers.{layer_idx}", f"{dst_prefix}.encoder.layers.{layer_idx}" + ) + ) + + # proj + renamed_keys.extend( + [ + (f"{src_prefix}.mask_features.weight", f"{dst_prefix}.mask_projection.weight"), + (f"{src_prefix}.mask_features.bias", f"{dst_prefix}.mask_projection.bias"), + ] + ) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + # Transformer Decoder + def replace_keys_qkv_transformer_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "transformer_module.decoder.layers" + src_prefix: str = "sem_seg_head.predictor" + for i in range(self.config.decoder_layers - 1): + # read in weights + bias of input projection layer of self-attention + in_proj_weight = src_state_dict.pop( + f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.in_proj_weight" + ) + in_proj_bias = src_state_dict.pop( + f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.in_proj_bias" + ) + # next, add query, keys and values (in that order) to the state dict + dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.q_proj.weight"] = in_proj_weight[:256, :] + dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.q_proj.bias"] = in_proj_bias[:256] + dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.k_proj.bias"] = in_proj_bias[256:512] + dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.v_proj.weight"] = in_proj_weight[-256:, :] + dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.v_proj.bias"] = in_proj_bias[-256:] + + def replace_transformer_module(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "transformer_module" + src_prefix: str = "sem_seg_head.predictor" + + def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str): + return [ + (f"{src_prefix}.weight", f"{dst_prefix}.weight"), + (f"{src_prefix}.bias", f"{dst_prefix}.bias"), + ] + + def rename_keys_for_attn(src_prefix: str, dst_prefix: str): + attn_keys = [ + (f"{src_prefix}.in_proj_bias", f"{dst_prefix}.in_proj_bias"), + (f"{src_prefix}.in_proj_weight", f"{dst_prefix}.in_proj_weight"), + ] + attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj")) + + return attn_keys + + def rename_keys_for_self_attn(src_prefix: str, dst_prefix: str): + attn_keys = [] + attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj")) + + return attn_keys + + def rename_keys_for_query_transformer_layer(src_prefix: str, dst_prefix: str): + query_transformer_layer_keys = [] + + query_transformer_layer_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.linear1") + ) + query_transformer_layer_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.linear2") + ) + query_transformer_layer_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.norm1", f"{dst_prefix}.norm1") + ) + query_transformer_layer_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.norm2", f"{dst_prefix}.norm2") + ) + query_transformer_layer_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.norm3", f"{dst_prefix}.norm3") + ) + + query_transformer_layer_keys.extend( + rename_keys_for_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn") + ) + + query_transformer_layer_keys.extend( + rename_keys_for_attn(f"{src_prefix}.multihead_attn", f"{dst_prefix}.multihead_attn") + ) + + return query_transformer_layer_keys + + def rename_keys_for_cross_attn_layer(src_prefix: str, dst_prefix: str): + cross_attn_layer_keys = [] + + cross_attn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm")) + cross_attn_layer_keys.extend( + rename_keys_for_attn(f"{src_prefix}.multihead_attn", f"{dst_prefix}.multihead_attn") + ) + + return cross_attn_layer_keys + + def rename_keys_for_self_attn_layer(src_prefix: str, dst_prefix: str): + self_attn_layer_keys = [] + + self_attn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm")) + self_attn_layer_keys.extend( + rename_keys_for_self_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn") + ) + + return self_attn_layer_keys + + def rename_keys_for_ffn_layer(src_prefix: str, dst_prefix: str): + ffn_layer_keys = [] + + ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.linear1")) + ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.linear2")) + ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm")) + + return ffn_layer_keys + + def rename_keys_for_transformer_decoder_layer(src_prefix: str, dst_prefix: str, idx: int): + transformer_decoder_layer_keys = [] + + transformer_decoder_layer_keys.extend( + rename_keys_for_cross_attn_layer( + f"{src_prefix}.transformer_cross_attention_layers.{idx}", f"{dst_prefix}.{idx}.cross_attn" + ) + ) + + transformer_decoder_layer_keys.extend( + rename_keys_for_self_attn_layer( + f"{src_prefix}.transformer_self_attention_layers.{idx}", f"{dst_prefix}.{idx}.self_attn" + ) + ) + + transformer_decoder_layer_keys.extend( + rename_keys_for_ffn_layer(f"{src_prefix}.transformer_ffn_layers.{idx}", f"{dst_prefix}.{idx}.ffn") + ) + + return transformer_decoder_layer_keys + + # positional embedding for object queries + renamed_keys = [ + (f"{src_prefix}.query_embed.weight", f"{dst_prefix}.queries_embedder.weight"), + (f"{src_prefix}.level_embed.weight", f"{dst_prefix}.level_embed.weight"), + ] + + # norm + renamed_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.decoder_norm", f"{dst_prefix}.decoder.decoder_norm") + ) + + # proj + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.class_input_proj", f"{dst_prefix}.decoder.query_input_projection" + ) + ) + + renamed_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.class_embed", f"{dst_prefix}.decoder.class_embed") + ) + + for i in range(3): + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.mask_embed.layers.{i}", f"{dst_prefix}.decoder.mask_embed.layers.{i}.0" + ) + ) + + # norm + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.class_transformer.decoder.norm", f"{dst_prefix}.decoder.query_transformer.decoder.norm" + ) + ) + + # transformer to update queries with task tokens + for i in range(self.config.query_dec_layers): + renamed_keys.extend( + rename_keys_for_query_transformer_layer( + f"{src_prefix}.class_transformer.decoder.layers.{i}", + f"{dst_prefix}.decoder.query_transformer.decoder.layers.{i}", + ) + ) + + # decoder layers + for i in range(self.config.decoder_layers - 1): + renamed_keys.extend( + rename_keys_for_transformer_decoder_layer( + f"{src_prefix}", + f"{dst_prefix}.decoder.layers", + i, + ) + ) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + self.replace_keys_qkv_transformer_decoder(dst_state_dict, src_state_dict) + + def replace_task_mlp(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "task_encoder" + src_prefix: str = "task_mlp" + + def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str): + return [ + (f"{src_prefix}.weight", f"{dst_prefix}.weight"), + (f"{src_prefix}.bias", f"{dst_prefix}.bias"), + ] + + renamed_keys = [] + + for i in range(2): + renamed_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.layers.{i}", f"{dst_prefix}.task_mlp.layers.{i}.0") + ) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + def replace_text_projector(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "text_mapper.text_projector" + src_prefix: str = "text_projector" + + def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str): + return [ + (f"{src_prefix}.weight", f"{dst_prefix}.weight"), + (f"{src_prefix}.bias", f"{dst_prefix}.bias"), + ] + + renamed_keys = [] + + for i in range(self.config.text_encoder_config["text_encoder_proj_layers"]): + renamed_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.layers.{i}", f"{dst_prefix}.{i}.0")) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + def replace_text_mapper(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "text_mapper.text_encoder" + src_prefix: str = "text_encoder" + + self.replace_text_projector(dst_state_dict, src_state_dict) + + def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str): + return [ + (f"{src_prefix}.weight", f"{dst_prefix}.weight"), + (f"{src_prefix}.bias", f"{dst_prefix}.bias"), + ] + + def rename_keys_for_attn(src_prefix: str, dst_prefix: str): + attn_keys = [ + (f"{src_prefix}.in_proj_bias", f"{dst_prefix}.in_proj_bias"), + (f"{src_prefix}.in_proj_weight", f"{dst_prefix}.in_proj_weight"), + ] + attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj")) + + return attn_keys + + def rename_keys_for_layer(src_prefix: str, dst_prefix: str): + resblock_keys = [] + + resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.mlp.c_fc", f"{dst_prefix}.mlp.fc1")) + resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.mlp.c_proj", f"{dst_prefix}.mlp.fc2")) + resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.ln_1", f"{dst_prefix}.layer_norm1")) + resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.ln_2", f"{dst_prefix}.layer_norm2")) + resblock_keys.extend(rename_keys_for_attn(f"{src_prefix}.attn", f"{dst_prefix}.self_attn")) + + return resblock_keys + + renamed_keys = [ + ("prompt_ctx.weight", "text_mapper.prompt_ctx.weight"), + ] + + renamed_keys.extend( + [ + (f"{src_prefix}.positional_embedding", f"{dst_prefix}.positional_embedding"), + (f"{src_prefix}.token_embedding.weight", f"{dst_prefix}.token_embedding.weight"), + ] + ) + + renamed_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.ln_final", f"{dst_prefix}.ln_final")) + + for i in range(self.config.text_encoder_config["text_encoder_num_layers"]): + renamed_keys.extend( + rename_keys_for_layer( + f"{src_prefix}.transformer.resblocks.{i}", f"{dst_prefix}.transformer.layers.{i}" + ) + ) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + def convert(self, oneformer: OneFormerModel, is_swin: bool) -> OneFormerModel: + dst_state_dict = TrackedStateDict(oneformer.state_dict()) + src_state_dict = self.original_model.state_dict() + + self.replace_pixel_module(dst_state_dict, src_state_dict, is_swin) + self.replace_transformer_module(dst_state_dict, src_state_dict) + self.replace_task_mlp(dst_state_dict, src_state_dict) + if self.config.is_training: + self.replace_text_mapper(dst_state_dict, src_state_dict) + + logger.info(f"Missed keys are {pformat(dst_state_dict.diff())}") + logger.info(f"Not copied keys are {pformat(src_state_dict.keys())}") + logger.info("🙌 Done") + + oneformer.load_state_dict(dst_state_dict) + + return oneformer + + @staticmethod + def using_dirs(checkpoints_dir: Path, config_dir: Path) -> Iterator[Tuple[object, Path, Path]]: + checkpoints: List[Path] = checkpoints_dir.glob("**/*.pth") + + for checkpoint in checkpoints: + logger.info(f"💪 Converting {checkpoint.stem}") + # find associated config file + config: Path = config_dir / f"{checkpoint.stem}.yaml" + + yield config, checkpoint + + +def post_process_sem_seg_output(outputs: OneFormerForUniversalSegmentationOutput, target_size: Tuple[int, int]): + # class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1] + class_queries_logits = outputs.class_queries_logits + # masks_queries_logits has shape [BATCH, QUERIES, HEIGHT, WIDTH] + masks_queries_logits = outputs.masks_queries_logits + if target_size is not None: + masks_queries_logits = torch.nn.functional.interpolate( + masks_queries_logits, + size=target_size, + mode="bilinear", + align_corners=False, + ) + # remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] + # mask probs has shape [BATCH, QUERIES, HEIGHT, WIDTH] + masks_probs = masks_queries_logits.sigmoid() + # now we want to sum over the queries, + # $ out_{c,h,w} = \sum_q p_{q,c} * m_{q,h,w} $ + # where $ softmax(p) \in R^{q, c} $ is the mask classes + # and $ sigmoid(m) \in R^{q, h, w}$ is the mask probabilities + # b(atch)q(uery)c(lasses), b(atch)q(uery)h(eight)w(idth) + segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + + return segmentation + + +def test( + original_model, + our_model: OneFormerForUniversalSegmentation, + processor: OneFormerProcessor, + model_repo: str, +): + def _preprocess_text(text_list=None, max_length=77): + if text_list is None: + raise ValueError("tokens cannot be None.") + + tokens = tokenizer(text_list, padding="max_length", max_length=max_length, truncation=True) + + attention_masks, input_ids = tokens["attention_mask"], tokens["input_ids"] + + token_inputs = [] + for attn_mask, input_id in zip(attention_masks, input_ids): + token = torch.tensor(attn_mask) * torch.tensor(input_id) + token_inputs.append(token.unsqueeze(0)) + + token_inputs = torch.cat(token_inputs, dim=0) + return token_inputs + + with torch.no_grad(): + tokenizer = CLIPTokenizer.from_pretrained(model_repo) + original_model = original_model.eval() + our_model = our_model.eval() + + im = prepare_img() + + tr = T.Compose( + [ + T.Resize((640, 640)), + T.ToTensor(), + T.Normalize( + mean=torch.tensor([123.675, 116.280, 103.530]) / 255.0, + std=torch.tensor([58.395, 57.120, 57.375]) / 255.0, + ), + ], + ) + + x = tr(im).unsqueeze(0) + + task_input = ["the task is semantic"] + task_token = _preprocess_text(task_input, max_length=processor.task_seq_length) + + original_model_backbone_features = original_model.backbone(x.clone()) + + our_model_output: OneFormerModelOutput = our_model.model(x.clone(), task_token, output_hidden_states=True) + + for original_model_feature, our_model_feature in zip( + original_model_backbone_features.values(), our_model_output.encoder_hidden_states + ): + assert torch.allclose( + original_model_feature, our_model_feature, atol=3e-3 + ), "The backbone features are not the same." + mask_features, _, multi_scale_features, _, _ = original_model.sem_seg_head.pixel_decoder.forward_features( + original_model_backbone_features + ) + + original_pixel_decoder_features = [] + original_pixel_decoder_features.append(mask_features) + for i in range(len(multi_scale_features)): + original_pixel_decoder_features.append(multi_scale_features[i]) + + for original_model_feature, our_model_feature in zip( + original_pixel_decoder_features, our_model_output.pixel_decoder_hidden_states + ): + assert torch.allclose( + original_model_feature, our_model_feature, atol=3e-4 + ), "The pixel decoder feature are not the same" + + tr_complete = T.Compose( + [ + T.Resize((640, 640)), + T.ToTensor(), + ], + ) + + y = (tr_complete(im) * 255.0).to(torch.int).float() + + # let's test the full model + original_model_out = original_model([{"image": y.clone(), "task": "The task is semantic"}]) + + original_segmentation = original_model_out[0]["sem_seg"] + + our_model_out: OneFormerForUniversalSegmentationOutput = our_model( + x.clone(), task_token, output_hidden_states=True + ) + + our_segmentation = post_process_sem_seg_output(our_model_out, target_size=(640, 640))[0] + + assert torch.allclose( + original_segmentation, our_segmentation, atol=1e-3 + ), "The segmentation image is not the same." + + logger.info("✅ Test passed!") + + +def get_name(checkpoint_file: Path): + model_name_raw: str = checkpoint_file.stem + + backbone = "swin" if "swin" in model_name_raw else "dinat" + dataset = "" + if "coco" in model_name_raw: + dataset = "coco" + elif "ade20k" in model_name_raw: + dataset = "ade20k" + elif "cityscapes" in model_name_raw: + dataset = "cityscapes" + else: + raise ValueError( + f"{model_name_raw} must be wrong since we didn't find 'coco' or 'ade20k' or 'cityscapes' in it " + ) + + backbone_types = ["tiny", "large"] + + backbone_type = list(filter(lambda x: x in model_name_raw, backbone_types))[0] + + model_name = f"oneformer_{dataset}_{backbone}_{backbone_type}" + + return model_name + + +if __name__ == "__main__": + parser = ArgumentParser( + description=( + "Command line to convert the original oneformer models (with swin backbone) to transformers" + " implementation." + ) + ) + + parser.add_argument( + "--checkpoints_dir", + type=Path, + help=( + "A directory containing the model's checkpoints. The directory has to have the following structure:" + " structure: //.pth; where name must follow the" + " following nomenclature nomenclature: oneformer___" + ), + ) + parser.add_argument( + "--configs_dir", + type=Path, + help=( + "A directory containing the model's configs, see detectron2 doc. The directory has to have the following" + " structure: //.yaml; where name must follow the" + " following nomenclature nomenclature: oneformer___" + ), + ) + parser.add_argument( + "--pytorch_dump_folder_path", + required=True, + type=Path, + help="Path to the folder to output PyTorch models.", + ) + parser.add_argument( + "--oneformer_dir", + required=True, + type=Path, + help=( + "A path to OneFormer's original implementation directory. You can download from here:" + "https://github.com/SHI-Labs/OneFormer" + ), + ) + + args = parser.parse_args() + + checkpoints_dir: Path = args.checkpoints_dir + config_dir: Path = args.configs_dir + save_directory: Path = args.pytorch_dump_folder_path + oneformer_dir: Path = args.oneformer_dir + # append the path to the parents to oneformer dir + sys.path.append(str(oneformer_dir.parent)) + # and import what's needed + from OneFormer.oneformer import add_common_config, add_dinat_config, add_oneformer_config, add_swin_config + from OneFormer.oneformer.oneformer_model import OneFormer as OriginalOneFormer + + if not save_directory.exists(): + save_directory.mkdir(parents=True) + + for config_file, checkpoint_file in OriginalOneFormerCheckpointToOursConverter.using_dirs( + checkpoints_dir, config_dir + ): + processor = OriginalOneFormerConfigToProcessorConverter()( + setup_cfg(Args(config_file=config_file)), os.path.join("shi-labs", config_file.stem) + ) + + original_config = setup_cfg(Args(config_file=config_file)) + oneformer_kwargs = OriginalOneFormer.from_config(original_config) + + original_model = OriginalOneFormer(**oneformer_kwargs).eval() + + DetectionCheckpointer(original_model).load(str(checkpoint_file)) + + is_swin = "swin" in config_file.stem + + config: OneFormerConfig = OriginalOneFormerConfigToOursConverter()(original_config, is_swin) + + oneformer = OneFormerModel(config=config).eval() + + converter = OriginalOneFormerCheckpointToOursConverter(original_model, config) + + oneformer = converter.convert(oneformer, is_swin) + + oneformer_for_universal_segmentation = OneFormerForUniversalSegmentation(config=config).eval() + + oneformer_for_universal_segmentation.model = oneformer + + test( + original_model, + oneformer_for_universal_segmentation, + processor, + os.path.join("shi-labs", config_file.stem), + ) + + model_name = get_name(checkpoint_file) + logger.info(f"🪄 Saving {model_name}") + + processor.save_pretrained(save_directory / model_name) + oneformer_for_universal_segmentation.save_pretrained(save_directory / model_name) + + processor.push_to_hub( + repo_id=os.path.join("shi-labs", config_file.stem), + commit_message="Add configs", + use_temp_dir=True, + ) + oneformer_for_universal_segmentation.push_to_hub( + repo_id=os.path.join("shi-labs", config_file.stem), + commit_message="Add model", + use_temp_dir=True, + ) diff --git a/src/transformers/models/oneformer/image_processing_oneformer.py b/src/transformers/models/oneformer/image_processing_oneformer.py new file mode 100644 index 000000000000..2cbe1ca5bf11 --- /dev/null +++ b/src/transformers/models/oneformer/image_processing_oneformer.py @@ -0,0 +1,1233 @@ +# coding=utf-8 +# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for OneFormer.""" + +import json +import warnings +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union + +import numpy as np + +from huggingface_hub import hf_hub_download +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from transformers.image_transforms import ( + PaddingMode, + get_resize_output_image_size, + normalize, + pad, + rescale, + resize, + to_channel_dimension_format, + to_numpy_array, +) +from transformers.image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_batched, + valid_images, +) +from transformers.utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + TensorType, + is_torch_available, + is_torch_tensor, + logging, +) + + +logger = logging.get_logger(__name__) + + +if is_torch_available(): + import torch + from torch import nn + + +# Copied from transformers.models.detr.image_processing_detr.max_across_indices +def max_across_indices(values: Iterable[Any]) -> List[Any]: + """ + Return the maximum value across all indices of an iterable of values. + """ + return [max(values_i) for values_i in zip(*values)] + + +# Copied from transformers.models.detr.image_processing_detr.get_max_height_width +def get_max_height_width(images: List[np.ndarray]) -> List[int]: + """ + Get the maximum height and width across all images in a batch. + """ + input_channel_dimension = infer_channel_dimension_format(images[0]) + + if input_channel_dimension == ChannelDimension.FIRST: + _, max_height, max_width = max_across_indices([img.shape for img in images]) + elif input_channel_dimension == ChannelDimension.LAST: + max_height, max_width, _ = max_across_indices([img.shape for img in images]) + else: + raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}") + return (max_height, max_width) + + +# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask +def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray: + """ + Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. + + Args: + image (`np.ndarray`): + Image to make the pixel mask for. + output_size (`Tuple[int, int]`): + Output size of the mask. + """ + input_height, input_width = get_image_size(image) + mask = np.zeros(output_size, dtype=np.int64) + mask[:input_height, :input_width] = 1 + return mask + + +# Copied from transformers.models.detr.image_processing_detr.binary_mask_to_rle +def binary_mask_to_rle(mask): + """ + Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format. + + Args: + mask (`torch.Tensor` or `numpy.array`): + A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target + segment_id or class_id. + Returns: + `List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE + format. + """ + if is_torch_tensor(mask): + mask = mask.numpy() + + pixels = mask.flatten() + pixels = np.concatenate([[0], pixels, [0]]) + runs = np.where(pixels[1:] != pixels[:-1])[0] + 1 + runs[1::2] -= runs[::2] + return [x for x in runs] + + +# Copied from transformers.models.detr.image_processing_detr.convert_segmentation_to_rle +def convert_segmentation_to_rle(segmentation): + """ + Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format. + + Args: + segmentation (`torch.Tensor` or `numpy.array`): + A segmentation map of shape `(height, width)` where each value denotes a segment or class id. + Returns: + `List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id. + """ + segment_ids = torch.unique(segmentation) + + run_length_encodings = [] + for idx in segment_ids: + mask = torch.where(segmentation == idx, 1, 0) + rle = binary_mask_to_rle(mask) + run_length_encodings.append(rle) + + return run_length_encodings + + +# Copied from transformers.models.detr.image_processing_detr.remove_low_and_no_objects +def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels): + """ + Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and + `labels`. + + Args: + masks (`torch.Tensor`): + A tensor of shape `(num_queries, height, width)`. + scores (`torch.Tensor`): + A tensor of shape `(num_queries)`. + labels (`torch.Tensor`): + A tensor of shape `(num_queries)`. + object_mask_threshold (`float`): + A number between 0 and 1 used to binarize the masks. + Raises: + `ValueError`: Raised when the first dimension doesn't match in all input tensors. + Returns: + `Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region + < `object_mask_threshold`. + """ + if not (masks.shape[0] == scores.shape[0] == labels.shape[0]): + raise ValueError("mask, scores and labels must have the same shape!") + + to_keep = labels.ne(num_labels) & (scores > object_mask_threshold) + + return masks[to_keep], scores[to_keep], labels[to_keep] + + +# Copied from transformers.models.detr.image_processing_detr.check_segment_validity +def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8): + # Get the mask associated with the k class + mask_k = mask_labels == k + mask_k_area = mask_k.sum() + + # Compute the area of all the stuff in query k + original_area = (mask_probs[k] >= mask_threshold).sum() + mask_exists = mask_k_area > 0 and original_area > 0 + + # Eliminate disconnected tiny segments + if mask_exists: + area_ratio = mask_k_area / original_area + if not area_ratio.item() > overlap_mask_area_threshold: + mask_exists = False + + return mask_exists, mask_k + + +# Copied from transformers.models.detr.image_processing_detr.compute_segments +def compute_segments( + mask_probs, + pred_scores, + pred_labels, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + label_ids_to_fuse: Optional[Set[int]] = None, + target_size: Tuple[int, int] = None, +): + height = mask_probs.shape[1] if target_size is None else target_size[0] + width = mask_probs.shape[2] if target_size is None else target_size[1] + + segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device) + segments: List[Dict] = [] + + if target_size is not None: + mask_probs = nn.functional.interpolate( + mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False + )[0] + + current_segment_id = 0 + + # Weigh each mask by its prediction score + mask_probs *= pred_scores.view(-1, 1, 1) + mask_labels = mask_probs.argmax(0) # [height, width] + + # Keep track of instances of each class + stuff_memory_list: Dict[str, int] = {} + for k in range(pred_labels.shape[0]): + pred_class = pred_labels[k].item() + should_fuse = pred_class in label_ids_to_fuse + + # Check if mask exists and large enough to be a segment + mask_exists, mask_k = check_segment_validity( + mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold + ) + + if mask_exists: + if pred_class in stuff_memory_list: + current_segment_id = stuff_memory_list[pred_class] + else: + current_segment_id += 1 + + # Add current object segment to final segmentation map + segmentation[mask_k] = current_segment_id + segment_score = round(pred_scores[k].item(), 6) + segments.append( + { + "id": current_segment_id, + "label_id": pred_class, + "was_fused": should_fuse, + "score": segment_score, + } + ) + if should_fuse: + stuff_memory_list[pred_class] = current_segment_id + + return segmentation, segments + + +# Copied from transformers.models.maskformer.image_processing_maskformer.convert_segmentation_map_to_binary_masks +def convert_segmentation_map_to_binary_masks( + segmentation_map: "np.ndarray", + instance_id_to_semantic_id: Optional[Dict[int, int]] = None, + ignore_index: Optional[int] = None, + reduce_labels: bool = False, +): + if reduce_labels and ignore_index is None: + raise ValueError("If `reduce_labels` is True, `ignore_index` must be provided.") + + if reduce_labels: + segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1) + + # Get unique ids (class or instance ids based on input) + all_labels = np.unique(segmentation_map) + + # Drop background label if applicable + if ignore_index is not None: + all_labels = all_labels[all_labels != ignore_index] + + # Generate a binary mask for each object instance + binary_masks = [(segmentation_map == i) for i in all_labels] + binary_masks = np.stack(binary_masks, axis=0) # (num_labels, height, width) + + # Convert instance ids to class ids + if instance_id_to_semantic_id is not None: + labels = np.zeros(all_labels.shape[0]) + + for label in all_labels: + class_id = instance_id_to_semantic_id[label + 1 if reduce_labels else label] + labels[all_labels == label] = class_id - 1 if reduce_labels else class_id + else: + labels = all_labels + + return binary_masks.astype(np.float32), labels.astype(np.int64) + + +def get_oneformer_resize_output_image_size( + image: np.ndarray, + size: Union[int, Tuple[int, int], List[int], Tuple[int]], + max_size: Optional[int] = None, + default_to_square: bool = True, +) -> tuple: + """ + Computes the output size given the desired size. + + Args: + input_image (`np.ndarray`): + The input image. + size (`int`, `Tuple[int, int]`, `List[int]`, `Tuple[int]`): + The size of the output image. + default_to_square (`bool`, *optional*, defaults to `True`): + Whether to default to square if no size is provided. + max_size (`int`, *optional*): + The maximum size of the output image. + + Returns: + `Tuple[int, int]`: The output size. + """ + output_size = get_resize_output_image_size( + input_image=image, size=size, default_to_square=default_to_square, max_size=max_size + ) + return output_size + + +def prepare_metadata(repo_path, class_info_file): + with open(hf_hub_download(repo_path, class_info_file, repo_type="dataset"), "r") as f: + class_info = json.load(f) + metadata = {} + class_names = [] + thing_ids = [] + for key, info in class_info.items(): + metadata[key] = info["name"] + class_names.append(info["name"]) + if info["isthing"]: + thing_ids.append(int(key)) + metadata["thing_ids"] = thing_ids + metadata["class_names"] = class_names + return metadata + + +class OneFormerImageProcessor(BaseImageProcessor): + r""" + Constructs a OneFormer image processor. The image processor can be used to prepare image(s), task input(s) and + optional text inputs and targets for the model. + + This image processor inherits from [`BaseImageProcessor`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the input to a certain `size`. + size (`int`, *optional*, defaults to 800): + Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. If size is a + sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of + the image will be matched to this number. i.e, if `height > width`, then image will be rescaled to `(size * + height / width, size)`. + max_size (`int`, *optional*, defaults to 1333): + The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is + set to `True`. + resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`): + An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`, + `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`, + `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set + to `True`. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the input to a certain `scale`. + rescale_factor (`float`, *optional*, defaults to 1/ 255): + Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the input with mean and standard deviation. + image_mean (`int`, *optional*, defaults to `[0.485, 0.456, 0.406]`): + The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean. + image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`): + The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the + ImageNet std. + ignore_index (`int`, *optional*): + Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels + denoted with 0 (background) will be replaced with `ignore_index`. + reduce_labels (`bool`, *optional*, defaults to `False`): + Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0 + is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). + The background label will be replaced by `ignore_index`. + repo_path (`str`, defaults to `shi-labs/oneformer_demo`): + Dataset repository on huggingface hub containing the JSON file with class information for the dataset. + class_info_file (`str`): + JSON file containing class information for the dataset. It is stored inside on the `repo_path` dataset + repository. + num_text (`int`, *optional*): + Number of text entries in the text input list. + """ + + model_input_names = ["pixel_values", "pixel_mask", "task_inputs"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: float = 1 / 255, + do_normalize: bool = True, + image_mean: Union[float, List[float]] = None, + image_std: Union[float, List[float]] = None, + ignore_index: Optional[int] = None, + reduce_labels: bool = False, + repo_path: str = "shi-labs/oneformer_demo", + class_info_file: str = None, + num_text: Optional[int] = None, + **kwargs + ): + if "max_size" in kwargs: + self._max_size = kwargs.pop("max_size") + else: + self._max_size = 1333 + + size = size if size is not None else {"shortest_edge": 800, "longest_edge": self._max_size} + size = get_size_dict(size, max_size=self._max_size, default_to_square=False) + + super().__init__(**kwargs) + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.ignore_index = ignore_index + self.reduce_labels = reduce_labels + self.class_info_file = class_info_file + self.repo_path = repo_path + self.metadata = prepare_metadata(repo_path, class_info_file) + self.num_text = num_text + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format=None, + **kwargs + ) -> np.ndarray: + """ + Resize the image to the given size. Size can be min_size (scalar) or `(height, width)` tuple. If size is an + int, smaller edge of the image will be matched to this number. + """ + if "max_size" in kwargs: + warnings.warn( + "The `max_size` parameter is deprecated and will be removed in v4.27. " + "Please specify in `size['longest_edge'] instead`.", + FutureWarning, + ) + max_size = kwargs.pop("max_size") + else: + max_size = None + size = get_size_dict(size, max_size=max_size, default_to_square=False) + if "shortest_edge" in size and "longest_edge" in size: + size, max_size = size["shortest_edge"], size["longest_edge"] + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + max_size = None + else: + raise ValueError( + "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got" + f" {size.keys()}." + ) + size = get_oneformer_resize_output_image_size( + image=image, + size=size, + max_size=max_size, + default_to_square=False, + ) + image = resize(image, size=size, resample=resample, data_format=data_format) + return image + + # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.rescale + def rescale( + self, image: np.ndarray, rescale_factor: float, data_format: Optional[ChannelDimension] = None + ) -> np.ndarray: + """ + Rescale the image by the given factor. + """ + return rescale(image, rescale_factor, data_format=data_format) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize + def normalize( + self, + image: np.ndarray, + mean: Union[float, Iterable[float]], + std: Union[float, Iterable[float]], + data_format: Optional[ChannelDimension] = None, + ) -> np.ndarray: + """ + Normalize the image with the given mean and standard deviation. + """ + return normalize(image, mean=mean, std=std, data_format=data_format) + + # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.convert_segmentation_map_to_binary_masks + def convert_segmentation_map_to_binary_masks( + self, + segmentation_map: "np.ndarray", + instance_id_to_semantic_id: Optional[Dict[int, int]] = None, + ignore_index: Optional[int] = None, + reduce_labels: bool = False, + ): + reduce_labels = reduce_labels if reduce_labels is not None else self.reduce_labels + ignore_index = ignore_index if ignore_index is not None else self.ignore_index + return convert_segmentation_map_to_binary_masks( + segmentation_map=segmentation_map, + instance_id_to_semantic_id=instance_id_to_semantic_id, + ignore_index=ignore_index, + reduce_labels=reduce_labels, + ) + + def __call__(self, images, task_inputs, segmentation_maps=None, **kwargs) -> BatchFeature: + return self.preprocess(images, task_inputs, segmentation_maps=segmentation_maps, **kwargs) + + def _preprocess( + self, + image: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + ): + if do_resize: + image = self.resize(image, size=size, resample=resample) + if do_rescale: + image = self.rescale(image, rescale_factor=rescale_factor) + if do_normalize: + image = self.normalize(image, mean=image_mean, std=image_std) + return image + + def _preprocess_image( + self, + image: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single image.""" + # All transformations expect numpy arrays. + image = to_numpy_array(image) + image = self._preprocess( + image=image, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) + if data_format is not None: + image = to_channel_dimension_format(image, data_format) + return image + + def _preprocess_mask( + self, + segmentation_map: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + ) -> np.ndarray: + """Preprocesses a single mask.""" + segmentation_map = to_numpy_array(segmentation_map) + # Add channel dimension if missing - needed for certain transformations + added_channel_dim = False + if segmentation_map.ndim == 2: + added_channel_dim = True + segmentation_map = segmentation_map[None, ...] + # TODO: (Amy) + # Remork segmentation map processing to include reducing labels and resizing which doesn't + # drop segment IDs > 255. + segmentation_map = self._preprocess( + image=segmentation_map, + do_resize=do_resize, + resample=PILImageResampling.NEAREST, + size=size, + do_rescale=False, + do_normalize=False, + ) + # Remove extra channel dimension if added for processing + if added_channel_dim: + segmentation_map = segmentation_map.squeeze(0) + return segmentation_map + + def preprocess( + self, + images: ImageInput, + task_inputs: List[str], + segmentation_maps: Optional[ImageInput] = None, + instance_id_to_semantic_id: Optional[Dict[int, int]] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + ignore_index: Optional[int] = None, + reduce_labels: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + **kwargs + ) -> BatchFeature: + if "pad_and_return_pixel_mask" in kwargs: + warnings.warn( + "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version", + FutureWarning, + ) + + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False, max_size=self._max_size) + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + ignore_index = ignore_index if ignore_index is not None else self.ignore_index + reduce_labels = reduce_labels if reduce_labels is not None else self.reduce_labels + + if do_resize is not None and size is None: + raise ValueError("If `do_resize` is True, `size` must be provided.") + + if do_rescale is not None and rescale_factor is None: + raise ValueError("If `do_rescale` is True, `rescale_factor` must be provided.") + + if do_normalize is not None and (image_mean is None or image_std is None): + raise ValueError("If `do_normalize` is True, `image_mean` and `image_std` must be provided.") + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if segmentation_maps is not None and not valid_images(segmentation_maps): + raise ValueError( + "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if not is_batched(images): + images = [images] + segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None + + if segmentation_maps is not None and len(images) != len(segmentation_maps): + raise ValueError("Images and segmentation maps must have the same length.") + + images = [ + self._preprocess_image( + image, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + ) + for image in images + ] + + if segmentation_maps is not None: + segmentation_maps = [ + self._preprocess_mask(segmentation_map, do_resize, size) for segmentation_map in segmentation_maps + ] + encoded_inputs = self.encode_inputs( + images, + task_inputs, + segmentation_maps, + instance_id_to_semantic_id, + ignore_index, + reduce_labels, + return_tensors, + ) + return encoded_inputs + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image + def _pad_image( + self, + image: np.ndarray, + output_size: Tuple[int, int], + constant_values: Union[float, Iterable[float]] = 0, + data_format: Optional[ChannelDimension] = None, + ) -> np.ndarray: + """ + Pad an image with zeros to the given size. + """ + input_height, input_width = get_image_size(image) + output_height, output_width = output_size + + pad_bottom = output_height - input_height + pad_right = output_width - input_width + padding = ((0, pad_bottom), (0, pad_right)) + padded_image = pad( + image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format + ) + return padded_image + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad + def pad( + self, + images: List[np.ndarray], + constant_values: Union[float, Iterable[float]] = 0, + return_pixel_mask: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = None, + ) -> np.ndarray: + """ + Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width + in the batch and optionally returns their corresponding pixel mask. + + Args: + image (`np.ndarray`): + Image to pad. + constant_values (`float` or `Iterable[float]`, *optional*): + The value to use for the padding if `mode` is `"constant"`. + return_pixel_mask (`bool`, *optional*, defaults to `True`): + Whether to return a pixel mask. + input_channel_dimension (`ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be inferred from the input image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + """ + pad_size = get_max_height_width(images) + + padded_images = [ + self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format) + for image in images + ] + data = {"pixel_values": padded_images} + + if return_pixel_mask: + masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images] + data["pixel_mask"] = masks + + return BatchFeature(data=data, tensor_type=return_tensors) + + def get_semantic_annotations(self, label, num_class_obj): + annotation_classes = label["classes"] + annotation_masks = label["masks"] + + texts = ["a semantic photo"] * self.num_text + classes = [] + masks = [] + + for idx in range(len(annotation_classes)): + class_id = annotation_classes[idx] + mask = annotation_masks[idx] + if not np.all(mask is False): + if class_id not in classes: + cls_name = self.metadata[str(class_id)] + classes.append(class_id) + masks.append(mask) + num_class_obj[cls_name] += 1 + else: + idx = classes.index(class_id) + masks[idx] += mask + masks[idx] = np.clip(masks[idx], 0, 1) + + num = 0 + for i, cls_name in enumerate(self.metadata["class_names"]): + if num_class_obj[cls_name] > 0: + for _ in range(num_class_obj[cls_name]): + if num >= len(texts): + break + texts[num] = f"a photo with a {cls_name}" + num += 1 + + classes = np.array(classes) + masks = np.array(masks) + return classes, masks, texts + + def get_instance_annotations(self, label, num_class_obj): + annotation_classes = label["classes"] + annotation_masks = label["masks"] + + texts = ["an instance photo"] * self.num_text + classes = [] + masks = [] + + for idx in range(len(annotation_classes)): + class_id = annotation_classes[idx] + mask = annotation_masks[idx] + + if class_id in self.metadata["thing_ids"]: + if not np.all(mask is False): + cls_name = self.metadata[str(class_id)] + classes.append(class_id) + masks.append(mask) + num_class_obj[cls_name] += 1 + + num = 0 + for i, cls_name in enumerate(self.metadata["class_names"]): + if num_class_obj[cls_name] > 0: + for _ in range(num_class_obj[cls_name]): + if num >= len(texts): + break + texts[num] = f"a photo with a {cls_name}" + num += 1 + + classes = np.array(classes) + masks = np.array(masks) + return classes, masks, texts + + def get_panoptic_annotations(self, label, num_class_obj): + annotation_classes = label["classes"] + annotation_masks = label["masks"] + + texts = ["an panoptic photo"] * self.num_text + classes = [] + masks = [] + + for idx in range(len(annotation_classes)): + class_id = annotation_classes[idx] + mask = annotation_masks[idx].data + if not np.all(mask is False): + cls_name = self.metadata[str(class_id)] + classes.append(class_id) + masks.append(mask) + num_class_obj[cls_name] += 1 + + num = 0 + for i, cls_name in enumerate(self.metadata["class_names"]): + if num_class_obj[cls_name] > 0: + for _ in range(num_class_obj[cls_name]): + if num >= len(texts): + break + texts[num] = f"a photo with a {cls_name}" + num += 1 + + classes = np.array(classes) + masks = np.array(masks) + return classes, masks, texts + + def encode_inputs( + self, + pixel_values_list: List[ImageInput], + task_inputs: List[str], + segmentation_maps: ImageInput = None, + instance_id_to_semantic_id: Optional[Union[List[Dict[int, int]], Dict[int, int]]] = None, + ignore_index: Optional[int] = None, + reduce_labels: bool = False, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs + ): + """ + Pad images up to the largest image in a batch and create a corresponding `pixel_mask`. + + OneFormer addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps + will be converted to lists of binary masks and their respective labels. Let's see an example, assuming + `segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels = + [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for + each mask. + + Args: + pixel_values_list (`List[ImageInput]`): + List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height, + width)`. + + task_inputs (`List[str]`): + List of task values. + + segmentation_maps (`ImageInput`, *optional*): + The corresponding semantic segmentation maps with the pixel-wise annotations. + + (`bool`, *optional*, defaults to `True`): + Whether or not to pad images up to the largest image in a batch and create a pixel mask. + + If left to the default, will return a pixel mask that is: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*): + A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an + instance segmentation map where each pixel represents an instance id. Can be provided as a single + dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map + instance ids in each image separately. + + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): + If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor` + objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model. + - **pixel_mask** -- Pixel mask to be fed to a model (when `=True` or if `pixel_mask` is in + `self.model_input_names`). + - **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model + (when `annotations` are provided). + - **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when + `annotations` are provided). They identify the labels of `mask_labels`, e.g. the label of + `mask_labels[i][j]` if `class_labels[i][j]`. + - **text_inputs** -- Optional list of text string entries to be fed to a model (when `annotations` are + provided). They identify the binary masks present in the image. + """ + ignore_index = self.ignore_index if ignore_index is None else ignore_index + reduce_labels = self.reduce_labels if reduce_labels is None else reduce_labels + + if "pad_and_return_pixel_mask" in kwargs: + warnings.warn( + "The `pad_and_return_pixel_mask` argument has no effect and will be removed in v4.27", FutureWarning + ) + + pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list] + pad_size = get_max_height_width(pixel_values_list) + encoded_inputs = self.pad(pixel_values_list, return_tensors=return_tensors) + + annotations = None + if segmentation_maps is not None: + segmentation_maps = map(np.array, segmentation_maps) + annotations = [] + for idx, segmentation_map in enumerate(segmentation_maps): + # Use instance2class_id mapping per image + if isinstance(instance_id_to_semantic_id, list): + instance_id = instance_id_to_semantic_id[idx] + else: + instance_id = instance_id_to_semantic_id + # Use instance2class_id mapping per image + masks, classes = self.convert_segmentation_map_to_binary_masks( + segmentation_map, instance_id, ignore_index=ignore_index, reduce_labels=reduce_labels + ) + annotations.append({"masks": masks, "classes": classes}) + + if annotations is not None: + mask_labels = [] + class_labels = [] + text_inputs = [] + + num_class_obj = {} + for cls_name in self.metadata["class_names"]: + num_class_obj[cls_name] = 0 + + for i, label in enumerate(annotations): + task = task_inputs[i] + if task == "semantic": + classes, masks, texts = self.get_semantic_annotations(label, num_class_obj) + elif task == "instance": + classes, masks, texts = self.get_instance_annotations(label, num_class_obj) + if task == "panoptic": + classes, masks, texts = self.get_panoptic_annotations(label, num_class_obj) + + # we cannot batch them since they don't share a common class size + masks = [mask[None, ...] for mask in masks] + masks = [ + self._pad_image(image=mask, output_size=pad_size, constant_values=ignore_index) for mask in masks + ] + masks = np.concatenate(masks, axis=0) + mask_labels.append(torch.from_numpy(masks)) + class_labels.append(torch.from_numpy(classes).long()) + text_inputs.append(texts) + + encoded_inputs["mask_labels"] = mask_labels + encoded_inputs["class_labels"] = class_labels + encoded_inputs["text_inputs"] = text_inputs + + return encoded_inputs + + # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_semantic_segmentation + def post_process_semantic_segmentation( + self, outputs, target_sizes: Optional[List[Tuple[int, int]]] = None + ) -> "torch.Tensor": + """ + Converts the output of [`MaskFormerForInstanceSegmentation`] into semantic segmentation maps. Only supports + PyTorch. + + Args: + outputs ([`MaskFormerForInstanceSegmentation`]): + Raw outputs of the model. + target_sizes (`List[Tuple[int, int]]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction. If left to None, predictions will not be resized. + Returns: + `List[torch.Tensor]`: + A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width) + corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each + `torch.Tensor` correspond to a semantic class id. + """ + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + + # Remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] + masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Semantic segmentation logits of shape (batch_size, num_classes, height, width) + segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + batch_size = class_queries_logits.shape[0] + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if batch_size != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + semantic_segmentation = [] + for idx in range(batch_size): + resized_logits = torch.nn.functional.interpolate( + segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = segmentation.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + def post_process_instance_segmentation( + self, + outputs, + task_type: str = "instance", + is_demo: bool = True, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + target_sizes: Optional[List[Tuple[int, int]]] = None, + return_coco_annotation: Optional[bool] = False, + ): + """ + Converts the output of [`OneFormerForUniversalSegmentationOutput`] into image instance segmentation + predictions. Only supports PyTorch. + + Args: + outputs ([`OneFormerForUniversalSegmentationOutput`]): + The outputs from [`OneFormerForUniversalSegmentationOutput`]. + task_type (`str`, *optional)*, defaults to "instance"): + The post processing depends on the task token input. If the `task_type` is "panoptic", we need to + ignore the stuff predictions. + is_demo (`bool`, *optional)*, defaults to `True`): + Whether the model is in demo mode. If true, use threshold to predict final masks. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + target_sizes (`List[Tuple]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction in batch. If left to None, predictions will not be + resized. + return_coco_annotation (`bool`, *optional)*, defaults to `False`): + Whether to return predictions in COCO format. + + Returns: + `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set + to `None` if no mask if found above `threshold`. If `target_sizes` is specified, segmentation is resized + to the corresponding `target_sizes` entry. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- an integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise. + Multiple instances of the same class / label were fused and assigned a single `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. + """ + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + + batch_size = class_queries_logits.shape[0] + num_queries = class_queries_logits.shape[1] + num_classes = class_queries_logits.shape[-1] - 1 + + # Loop over items in batch size + results: List[Dict[str, torch.Tensor]] = [] + + for i in range(batch_size): + # [Q, K] + scores = torch.nn.functional.softmax(class_queries_logits[i], dim=-1)[:, :-1] + labels = torch.arange(num_classes).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1) + + # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False) + scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False) + labels_per_image = labels[topk_indices] + + topk_indices = topk_indices // num_classes + # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1) + mask_pred = masks_queries_logits[i][topk_indices] + + # Only consider scores with confidence over [threshold] for demo + if is_demo: + keep = scores_per_image > threshold + scores_per_image = scores_per_image[keep] + labels_per_image = labels_per_image[keep] + mask_pred = mask_pred[keep] + + # if this is panoptic segmentation, we only keep the "thing" classes + if task_type == "panoptic": + keep = torch.zeros_like(scores_per_image).bool() + for i, lab in enumerate(labels_per_image): + keep[i] = lab in self.metadata["thing_ids"] + + scores_per_image = scores_per_image[keep] + labels_per_image = labels_per_image[keep] + mask_pred = mask_pred[keep] + + if mask_pred.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_pred.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + if "ade20k" in self.class_info_file and not is_demo and "instance" in task_type: + for i in range(labels_per_image.shape[0]): + labels_per_image[i] = self.metadata["thing_ids"].index(labels_per_image[i].item()) + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_pred, + scores_per_image, + labels_per_image, + mask_threshold, + overlap_mask_area_threshold, + set(), + target_size, + ) + + # Return segmentation map in run-length encoding (RLE) format + if return_coco_annotation: + segmentation = convert_segmentation_to_rle(segmentation) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_panoptic_segmentation + def post_process_panoptic_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + label_ids_to_fuse: Optional[Set[int]] = None, + target_sizes: Optional[List[Tuple[int, int]]] = None, + ) -> List[Dict]: + """ + Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image panoptic segmentation + predictions. Only supports PyTorch. + + Args: + outputs ([`MaskFormerForInstanceSegmentationOutput`]): + The outputs from [`MaskFormerForInstanceSegmentation`]. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + label_ids_to_fuse (`Set[int]`, *optional*): + The labels in this state will have all their instances be fused together. For instance we could say + there can only be one sky in an image, but several persons, so the label ID for sky would be in that + set, but not the one for person. + target_sizes (`List[Tuple]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction in batch. If left to None, predictions will not be + resized. + + Returns: + `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set + to `None` if no mask if found above `threshold`. If `target_sizes` is specified, segmentation is resized + to the corresponding `target_sizes` entry. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- an integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise. + Multiple instances of the same class / label were fused and assigned a single `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. + """ + + if label_ids_to_fuse is None: + logger.warning("`label_ids_to_fuse` unset. No instance will be fused.") + label_ids_to_fuse = set() + + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: List[Dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=label_ids_to_fuse, + target_size=target_size, + ) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py new file mode 100644 index 000000000000..6167e0e15150 --- /dev/null +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -0,0 +1,3213 @@ +# coding=utf-8 +# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch OneFormer model.""" +import copy +import math +import warnings +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +from torch import Tensor, nn +from torch.cuda.amp import autocast + +from transformers import AutoBackbone +from transformers.utils import logging + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_scipy_available, + replace_return_docstrings, + requires_backends, +) +from .configuration_oneformer import OneFormerConfig + + +logger = logging.get_logger(__name__) + + +_CONFIG_FOR_DOC = "OneFormerConfig" +_CHECKPOINT_FOR_DOC = "shi-labs/oneformer_ade20k_swin_tiny" +_IMAGE_PROCESSOR_FOR_DOC = "OneFormerImageProcessor" + +ONEFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "shi-labs/oneformer_ade20k_swin_tiny", + # See all OneFormer models at https://huggingface.co/models?filter=oneformer +] + + +if is_scipy_available(): + from scipy.optimize import linear_sum_assignment + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def multiscale_deform_attn_core_pytorch( + value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor +) -> Tensor: + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape + value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = ( + value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width) + ) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) + # batch_size*num_heads, hidden_dim, num_queries, num_points + sampling_value_l_ = nn.functional.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + batch_size * num_heads, 1, num_queries, num_levels * num_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +# Copied from transformers.models.maskformer.modeling_maskformer.dice_loss +def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor: + r""" + Compute the DICE loss, similar to generalized IOU for masks as follows: + + $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x \cap y }{x \cup y + 1}} $$ + + In practice, since `labels` is a binary mask, (only 0s and 1s), dice can be computed as follow + + $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x * y }{x + y + 1}} $$ + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + num_masks (`int`): + The number of masks present in the current batch, used for normalization. + + Returns: + `torch.Tensor`: The computed loss. + """ + probs = inputs.sigmoid().flatten(1) + numerator = 2 * (probs * labels).sum(-1) + denominator = probs.sum(-1) + labels.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + loss = loss.sum() / num_masks + return loss + + +# Copied from transformers.models.mask2former.modeling_mask2former.sigmoid_cross_entropy_loss +def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor: + r""" + Args: + inputs (`torch.Tensor`): + A float tensor of arbitrary shape. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + + Returns: + loss (`torch.Tensor`): The computed loss. + """ + criterion = nn.BCEWithLogitsLoss(reduction="none") + cross_entropy_loss = criterion(inputs, labels) + + loss = cross_entropy_loss.mean(1).sum() / num_masks + return loss + + +# Copied from transformers.models.maskformer.modeling_maskformer.pair_wise_dice_loss +def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor: + """ + A pair wise version of the dice loss, see `dice_loss` for usage. + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + + Returns: + `torch.Tensor`: The computed loss between each pairs. + """ + inputs = inputs.sigmoid().flatten(1) + numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels) + # using broadcasting to get a [num_queries, NUM_CLASSES] matrix + denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + +# Copied from transformers.models.mask2former.modeling_mask2former.pair_wise_sigmoid_cross_entropy_loss +def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + r""" + A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage. + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + + Returns: + loss (`torch.Tensor`): The computed loss between each pairs. + """ + + height_and_width = inputs.shape[1] + + criterion = nn.BCEWithLogitsLoss(reduction="none") + cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) + cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) + + loss = torch.einsum("nc,mc->nm", cross_entropy_loss_pos, labels) + torch.einsum( + "nc,mc->nm", cross_entropy_loss_neg, (1 - labels) + ) + loss = loss / height_and_width + return loss + + +# Copied from transformers.models.mask2former.modeling_mask2former.sample_point +def sample_point( + input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs +) -> torch.Tensor: + """ + A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors. + + Args: + input_features (`torch.Tensor` of shape (batch_size, channels, height, width)): + A tensor that contains features map on a height * width grid + point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width,: + 2)): + A tensor that contains [0, 1] * [0, 1] normalized point coordinates + add_dim (`bool`): + boolean value to keep track of added dimension + + Returns: + point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels, + height_grid, width_grid): + A tensor that contains features for points in `point_coordinates`. + """ + if point_coordinates.dim() == 3: + add_dim = True + point_coordinates = point_coordinates.unsqueeze(2) + + # use nn.function.grid_sample to get features for points in `point_coordinates` via bilinear interpolation + point_features = torch.nn.functional.grid_sample(input_features, 2.0 * point_coordinates - 1.0, **kwargs) + if add_dim: + point_features = point_features.squeeze(3) + + return point_features + + +# Refactored from https://github.com/SHI-Labs/OneFormer/blob/33ebb56ed34f970a30ae103e786c0cb64c653d9a/oneformer/modeling/matcher.py#L93 +class OneFormerHungarianMatcher(nn.Module): + def __init__( + self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544 + ): + """This class computes an assignment between the labels and the predictions of the network. + + For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more + predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are + un-matched (and thus treated as non-objects). + + Params: + cost_class (float, *optional*, defaults to 1.0): + This is the relative weight of the classification error in the matching cost. + cost_mask (float, *optional*, defaults to 1.0): + This is the relative weight of the sigmoid ce loss of the binary mask in the matching cost. + cost_dice (float, *optional*, defaults to 1.0): + This is the relative weight of the dice loss of the binary mask in the matching cost + num_points (int, *optional*, defaults to 12544): + Number of points to be sampled for dice and mask loss matching cost. + """ + super().__init__() + if cost_class == 0 and cost_mask == 0 and cost_dice == 0: + raise ValueError("All costs cant be 0") + self.cost_class = cost_class + self.cost_mask = cost_mask + self.cost_dice = cost_dice + self.num_points = num_points + + @torch.no_grad() + def forward(self, masks_queries_logits, class_queries_logits, mask_labels, class_labels) -> List[Tuple[Tensor]]: + """Performs the matching + + Params: + masks_queries_logits (`torch.Tensor`): + A tensor` of dim `batch_size, num_queries, num_labels` with the + classification logits. + class_queries_logits (`torch.Tensor`): + A tensor` of dim `batch_size, num_queries, height, width` with the + predicted masks. + + class_labels (`torch.Tensor`): + A tensor` of dim `num_target_boxes` (where num_target_boxes is the number + of ground-truth objects in the target) containing the class labels. + mask_labels (`torch.Tensor`): + A tensor` of dim `num_target_boxes, height, width` containing the target + masks. + + Returns: + `List[Tuple[Tensor]]`: A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected labels (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_targets). + """ + indices: List[Tuple[np.array]] = [] + + num_queries = class_queries_logits.shape[1] + + preds_masks = masks_queries_logits + preds_probs = class_queries_logits + # iterate through batch size + for pred_probs, pred_mask, target_mask, labels in zip(preds_probs, preds_masks, mask_labels, class_labels): + pred_probs = pred_probs.softmax(-1) + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be ommitted. + cost_class = -pred_probs[:, labels] + + pred_mask = pred_mask[:, None] + target_mask = target_mask[:, None].to(pred_mask.device) + + # all masks share the same set of points for efficient matching! + point_coords = torch.rand(1, self.num_points, 2, device=pred_mask.device) + + # get ground truth labels + target_mask = sample_point( + target_mask, + point_coords.repeat(target_mask.shape[0], 1, 1), + align_corners=False, + ).squeeze(1) + + pred_mask = sample_point( + pred_mask, + point_coords.repeat(pred_mask.shape[0], 1, 1), + align_corners=False, + ).squeeze(1) + + with autocast(enabled=False): + pred_mask = pred_mask.float() + target_mask = target_mask.float() + + # compute the sigmoid ce loss + cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask) + # Compute the dice loss + cost_dice = pair_wise_dice_loss(pred_mask, target_mask) + # final cost matrix + cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice + cost_matrix = cost_matrix.reshape(num_queries, -1).cpu() + # do the assigmented using the hungarian algorithm in scipy + assigned_indices: Tuple[np.array] = linear_sum_assignment(cost_matrix.cpu()) + indices.append(assigned_indices) + + # It could be stacked in one tensor + matched_indices = [ + (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices + ] + return matched_indices + + +class OneFormerLoss(nn.Module): + def __init__( + self, + num_classes: int, + matcher: OneFormerHungarianMatcher, + weight_dict: Dict[str, float], + eos_coef: float, + num_points: int, + oversample_ratio: float, + importance_sample_ratio: float, + contrastive_temperature: float = None, + ): + """ + This class computes the losses using the class predictions, mask predictions and the contrastive queries. + + Oneformer calculates the classification CE loss on the class predictions. Mask predictions are used for + calculating the binary CE loss and dice loss. The contrastive queries are used for calculating the contrastive + loss. + + Args: + num_labels (`int`): + The number of classes. + matcher (`OneFormerHungarianMatcher`): + A torch module that computes the assigments between the predictions and labels. + weight_dict (`Dict[str, float]`): + A dictionary of weights to be applied to the different losses. + eos_coef (`float`): + Weight to apply to the null class. + num_points (`int`): + Number of points to be sampled for dice and mask loss calculations. + oversample_ratio (`float`): + Required for pointwise loss calculation. + importance_sample_ratio (`float`): + Required for pointwise loss calculation. + contrastive_temperature (`float`): + Temperature for scaling the contrastive logits. + """ + requires_backends(self, ["scipy"]) + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + empty_weight = torch.ones(self.num_classes + 1) + empty_weight[-1] = self.eos_coef + self.register_buffer("empty_weight", empty_weight) + + # pointwise mask loss parameters + self.num_points = num_points + self.oversample_ratio = oversample_ratio + self.importance_sample_ratio = importance_sample_ratio + self.contrastive_temperature = contrastive_temperature + if self.contrastive_temperature is not None: + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrastive_temperature)) + + def _max_by_axis(self, the_list: List[List[int]]) -> List[int]: + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]: + # get the maximum size in the batch + max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors]) + batch_size = len(tensors) + # compute finel size + batch_shape = [batch_size] + max_size + b, _, h, w = batch_shape + # get metadata + dtype = tensors[0].dtype + device = tensors[0].device + padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device) + padding_masks = torch.ones((b, h, w), dtype=torch.bool, device=device) + # pad the tensors to the size of the biggest one + for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks): + padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor) + padding_mask[: tensor.shape[1], : tensor.shape[2]] = False + + return padded_tensors, padding_masks + + def loss_contrastive(self, contrastive_queries_logits: Tensor, text_queries: Tensor): + """Compute the query-text contrastive loss. + + Args: + contrastive_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, hidden_dim` + text_queries (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, hidden_dim` + Returns: + `Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key: + - **loss_contrastive** -- The query-text contrastive loss computed using task-guided queries + and text queries derived from input text list. + """ + + image_queries = contrastive_queries_logits.float() + + # [batch_size, hidden_dim] + image_queries = nn.functional.normalize(image_queries.flatten(1), dim=-1) + text_queries = nn.functional.normalize(text_queries.flatten(1), dim=-1) + + logit_scale = torch.clamp(self.logit_scale.exp(), max=100) + + logits_per_text = torch.matmul(text_queries, image_queries.t()) * logit_scale + logits_per_img = logits_per_text.t() + + loss_img = nn.functional.cross_entropy( + logits_per_img, torch.arange(len(logits_per_img), device=logits_per_text.device) + ) + loss_text = nn.functional.cross_entropy( + logits_per_text, torch.arange(len(logits_per_text), device=logits_per_text.device) + ) + + loss_contrastive = loss_img + loss_text + + losses = {"loss_contrastive": loss_contrastive} + return losses + + def loss_labels( + self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array] + ) -> Dict[str, Tensor]: + """Compute the losses related to the labels using cross entropy. + + Args: + class_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, num_labels` + class_labels (`List[torch.Tensor]`): + List of class labels of shape `(labels)`. + indices (`Tuple[np.array])`: + The indices computed by the Hungarian matcher. + + Returns: + `Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key: + - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. + """ + pred_logits = class_queries_logits + batch_size, num_queries, _ = pred_logits.shape + criterion = nn.CrossEntropyLoss(weight=self.empty_weight) + idx = self._get_predictions_permutation_indices(indices) + + # shape = (batch_size, num_queries) + target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)]) + # shape = (batch_size, num_queries) + target_classes = torch.full( + (batch_size, num_queries), fill_value=self.num_classes, dtype=torch.int64, device=pred_logits.device + ) + target_classes[idx] = target_classes_o + # permute pred_logits (batch_size, num_queries, num_labels) -> (batch_size, num_labels, num_queries) + pred_logits_transposed = pred_logits.transpose(1, 2) + loss_ce = criterion(pred_logits_transposed, target_classes) + losses = {"loss_cross_entropy": loss_ce} + return losses + + def loss_masks( + self, masks_queries_logits: Tensor, mask_labels: List[Tensor], indices: Tuple[np.array], num_masks: int + ) -> Dict[str, Tensor]: + """Compute the losses related to the masks using focal and dice loss. + + Args: + masks_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, height, width` + mask_labels (`torch.Tensor`): + List of mask labels of shape `(labels, height, width)`. + indices (`Tuple[np.array])`: + The indices computed by the Hungarian matcher. + num_masks (`int)`: + The number of masks, used for normalization. + + Returns: + `Dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys: + - **loss_mask** -- The loss computed using sigmoid ce loss on the predicted and ground truth masks. + - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth + masks. + """ + src_idx = self._get_predictions_permutation_indices(indices) + tgt_idx = self._get_targets_permutation_indices(indices) + # shape (batch_size * num_queries, height, width) + pred_masks = masks_queries_logits[src_idx] + # shape (batch_size, num_queries, height, width) + # pad all and stack the targets to the num_labels dimension + # upsample predictions to the target size, we have to add one dim to use interpolate + target_masks, _ = self._pad_images_to_max_in_batch(mask_labels) + target_masks = target_masks[tgt_idx] + + pred_masks = pred_masks[:, None] + target_masks = target_masks[:, None] + + with torch.no_grad(): + # sample point_coords + point_coords = self.sample_points_using_uncertainty( + pred_masks, + self.calculate_uncertainty, + self.num_points, + self.oversample_ratio, + self.importance_sample_ratio, + ) + # get ground-truth labels + point_labels = sample_point(target_masks, point_coords, align_corners=False).squeeze(1) + + point_logits = sample_point(pred_masks, point_coords, align_corners=False).squeeze(1) + + losses = { + "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks), + "loss_dice": dice_loss(point_logits, point_labels, num_masks), + } + + del pred_masks + del target_masks + return losses + + # Copied from transformers.models.mask2former.modeling_mask2former.Mask2FormerLoss.calculate_uncertainty + def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor: + """ + In Mask2Former paper, uncertainty is estimated as L1 distance between 0.0 and the logit prediction in 'logits' + for the foreground class in `classes`. + + Args: + logits (`torch.Tensor`): + A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of predicted masks in all images and C is: + the number of foreground classes. The values are logits. + + Returns: + scores (`torch.Tensor`): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most + uncertain locations having the highest uncertainty score. + """ + uncertainty_scores = -(torch.abs(logits)) + return uncertainty_scores + + # Copied from transformers.models.mask2former.modeling_mask2former.Mask2FormerLoss.sample_points_using_uncertainty + def sample_points_using_uncertainty( + self, + logits: torch.Tensor, + uncertainty_function, + num_points: int, + oversample_ratio: int, + importance_sample_ratio: float, + ) -> torch.Tensor: + """ + This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The + uncertainty is calculated for each point using the passed `uncertainty function` that takes points logit + prediction as input. + + Args: + logits (`float`): + Logit predictions for P points. + uncertainty_function: + A function that takes logit predictions for P points and returns their uncertainties. + num_points (`int`): + The number of points P to sample. + oversample_ratio (`int`): + Oversampling parameter. + importance_sample_ratio (`float`): + Ratio of points that are sampled via importance sampling. + + Returns: + point_coordinates (`torch.Tensor`): + Coordinates for P sampled points. + """ + + num_boxes = logits.shape[0] + num_points_sampled = int(num_points * oversample_ratio) + + # Get random point coordinates + point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device) + # Get sampled prediction value for the point coordinates + point_logits = sample_point(logits, point_coordinates, align_corners=False) + # Calculate the uncertainties based on the sampled prediction values of the points + point_uncertainties = uncertainty_function(point_logits) + + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device) + idx += shift[:, None] + point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2) + + if num_random_points > 0: + point_coordinates = torch.cat( + [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)], + dim=1, + ) + return point_coordinates + + def _get_predictions_permutation_indices(self, indices): + # permute predictions following indices + batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + predictions_indices = torch.cat([src for (src, _) in indices]) + return batch_indices, predictions_indices + + def _get_targets_permutation_indices(self, indices): + # permute labels following indices + batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + target_indices = torch.cat([tgt for (_, tgt) in indices]) + return batch_indices, target_indices + + def forward( + self, + masks_queries_logits: Tensor, + class_queries_logits: Tensor, + contrastive_queries_logits: Tensor, + mask_labels: List[Tensor], + class_labels: List[Tensor], + text_queries: Tensor, + auxiliary_predictions: Optional[Dict[str, Tensor]] = None, + calculate_contrastive_loss: bool = True, + ) -> Dict[str, Tensor]: + """ + This performs the loss computation. + + Args: + masks_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, height, width` + class_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, num_labels` + contrastive_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, hidden_dim` + mask_labels (`torch.Tensor`): + List of mask labels of shape `(labels, height, width)`. + class_labels (`List[torch.Tensor]`): + List of class labels of shape `(labels)`. + text_queries (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, hidden_dim` + auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*): + if `use_auxiliary_loss` was set to `true` in [`OneFormerConfig`], then it contains the logits from the + inner layers of the Detr's Decoder. + calculate_contrastive_loss (`bool`, *optional*, defaults to `True`): + Whether or not to calculate the contrastive loss. + + Returns: + `Dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys: + - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. + - **loss_mask** -- The loss computed using sigmoid ce loss on the predicted and ground truth masks. + - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth + masks. + - **loss_contrastive** -- The query-text contrstive loss computed using object and text queries. + if `use_auxiliary_loss` was set to `true` in [`OneFormerConfig`], the dictionary contains addional losses + for each auxiliary predictions. + """ + + # retrieve the matching between the outputs of the last layer and the labels + indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels) + # compute the average number of target masks for normalization purposes + num_masks = self.get_num_masks(class_labels, device=class_labels[0].device) + # get all the losses + losses: Dict[str, Tensor] = { + **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks), + **self.loss_labels(class_queries_logits, class_labels, indices), + } + if calculate_contrastive_loss: + losses = {**losses, **self.loss_contrastive(contrastive_queries_logits, text_queries)} + + # in case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if auxiliary_predictions is not None: + for idx, aux_outputs in enumerate(auxiliary_predictions): + masks_queries_logits = aux_outputs["masks_queries_logits"] + class_queries_logits = aux_outputs["class_queries_logits"] + loss_dict = self.forward( + masks_queries_logits, + class_queries_logits, + None, + mask_labels, + class_labels, + None, + calculate_contrastive_loss=False, + ) + loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()} + losses.update(loss_dict) + + return losses + + def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor: + """ + Computes the average number of target masks across the batch, for normalization purposes. + """ + num_masks = sum([len(classes) for classes in class_labels]) + num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device) + return num_masks_pt + + +@dataclass +class OneFormerTransformerDecoderOutput(BaseModelOutput): + """ + Base class for outputs of the Transformer decoder. This class adds attributes for class predictions, mask + predictions and contrastive logits to BaseModelOutputWithCrossAttentions. + + Args: + object_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`): + Queries representation for the region proposals. + contrastive_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`): + Queries representation for the contrastive loss. + prediction_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`): + Mask predictions from last layer of the transformer decoder. + prediction_class (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes+1)`): + Class predictions from last layer of the transformer decoder. + auxiliary_predictions (Tuple of Dict of `str, torch.FloatTensor`, *optional*): + Tuple of class and mask predictions from each layer of the transformer decoder. + """ + + object_queries: torch.FloatTensor = None + contrastive_logits: Optional[torch.FloatTensor] = None + prediction_masks: torch.FloatTensor = None + prediction_class: torch.FloatTensor = None + auxiliary_predictions: Optional[Tuple[Dict[str, torch.FloatTensor]]] = None + + +@dataclass +# Copied from transformers.models.mask2former.modeling_mask2former.Mask2FormerPixelDecoderOutput with Mask2->One +class OneFormerPixelDecoderOutput(ModelOutput): + """ + OneFormer's pixel decoder module output, practically a Multi-Scale Deformable Attention based decoder. It returns + the mask features and the multiscale features. + + Args: + multi_scale_features (`tuple(torch.FloatTensor)`): + Tuple of multi-scale features of scales [1/8, 1/16, 1/32] and shape `(batch_size, num_channels, height, + width)`from the Multi-Scale Deformable Attenntion based Pixel Decoder. + mask_features (`torch.FloatTensor`): + Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel Decoder + Layer. + attentions (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights from pixel decoder. Returned when `output_attentions=True` is passed + or when `config.output_attentions=True` + """ + + multi_scale_features: Tuple[torch.FloatTensor] = None + mask_features: torch.FloatTensor = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class OneFormerPixelLevelModuleOutput(ModelOutput): + """ + OneFormer's pixel level module output. It returns both the last and (optionally) the hidden states from the + `encoder` and `decoder`. By default, the `encoder` is a Swin/Dinat Backbone and the `decoder` is a Multi-Scale + Deformable Attention based decoder. + + Args: + encoder_features (List of `(torch.FloatTensor)`): + List of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden-states (also + called feature maps) of the model at the output of each stage. + decoder_features (List of `(torch.FloatTensor)`): + List of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden-states (also + called feature maps) of the model at the output of each stage. + decoder_last_feature (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)): + 1/4 scale features from the last Pixel Decoder Layer. + """ + + encoder_features: List[torch.FloatTensor] = None + decoder_features: List[torch.FloatTensor] = None + decoder_last_feature: torch.FloatTensor = None + + +@dataclass +class OneFormerModelOutput(ModelOutput): + """ + Class for outputs of [`OneFormerModel`]. This class returns all the needed hidden states to compute the logits. + + Args: + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder + model at the output of each stage. + pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel + decoder model at the output of each stage. + transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the + transformer decoder at the output of each stage. + transformer_decoder_object_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`) + Output object queries from the last layer in the transformer decoder. + transformer_decoder_contrastive_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`) + Contrastive queries from the transformer decoder. + transformer_decoder_mask_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`) + Mask Predictions from the last layer in the transformer decoder. + transformer_decoder_class_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes+1)`): + Class Predictions from the last layer in the transformer decoder. + transformer_decoder_auxiliary_predictions (Tuple of Dict of `str, torch.FloatTensor`, *optional*): + Tuple of class and mask predictions from each layer of the transformer decoder. + text_queries (`torch.FloatTensor`, *optional* of shape `(batch_size, num_queries, hidden_dim)`) + Text queries derived from the input text list used for calculating contrastive loss during training. + task_token (`torch.FloatTensor` of shape `(batch_size, hidden_dim)`) + 1D task token to condition the queries. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Self and Cross Attentions weights from transformer decoder. + """ + + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + transformer_decoder_hidden_states: Optional[torch.FloatTensor] = None + transformer_decoder_object_queries: torch.FloatTensor = None + transformer_decoder_contrastive_queries: Optional[torch.FloatTensor] = None + transformer_decoder_mask_predictions: torch.FloatTensor = None + transformer_decoder_class_predictions: torch.FloatTensor = None + transformer_decoder_auxiliary_predictions: Optional[Tuple[Dict[str, torch.FloatTensor]]] = None + text_queries: Optional[torch.FloatTensor] = None + task_token: torch.FloatTensor = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class OneFormerForUniversalSegmentationOutput(ModelOutput): + """ + Class for outputs of [`OneFormerForUniversalSegmentationOutput`]. + + This output can be directly passed to [`~OneFormerImageProcessor.post_process_semantic_segmentation`] or + [`~OneFormerImageProcessor.post_process_instance_segmentation`] or + [`~OneFormerImageProcessor.post_process_panoptic_segmentation`] depending on the task. Please, see + [`~OneFormerImageProcessor] for details regarding usage. + + Args: + loss (`torch.Tensor`, *optional*): + The computed loss, returned when labels are present. + class_queries_logits (`torch.FloatTensor`): + A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each + query. Note the `+ 1` is needed because we incorporate the null class. + masks_queries_logits (`torch.FloatTensor`): + A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each + query. + auxiliary_predictions (List of Dict of `str, torch.FloatTensor`, *optional*): + List of class and mask predictions from each layer of the transformer decoder. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder + model at the output of each stage. + pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel + decoder model at the output of each stage. + transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the + transformer decoder at the output of each stage. + transformer_decoder_object_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`) + Output object queries from the last layer in the transformer decoder. + transformer_decoder_contrastive_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`) + Contrastive queries from the transformer decoder. + transformer_decoder_mask_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`) + Mask Predictions from the last layer in the transformer decoder. + transformer_decoder_class_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes+1)`): + Class Predictions from the last layer in the transformer decoder. + transformer_decoder_auxiliary_predictions (List of Dict of `str, torch.FloatTensor`, *optional*): + List of class and mask predictions from each layer of the transformer decoder. + text_queries (`torch.FloatTensor`, *optional* of shape `(batch_size, num_queries, hidden_dim)`) + Text queries derived from the input text list used for calculating contrastive loss during training. + task_token (`torch.FloatTensor` of shape `(batch_size, hidden_dim)`) + 1D task token to condition the queries. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Self and Cross Attentions weights from transformer decoder. + """ + + loss: Optional[torch.FloatTensor] = None + class_queries_logits: torch.FloatTensor = None + masks_queries_logits: torch.FloatTensor = None + auxiliary_predictions: List[Dict[str, torch.FloatTensor]] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + pixel_decoder_hidden_states: Optional[List[torch.FloatTensor]] = None + transformer_decoder_hidden_states: Optional[torch.FloatTensor] = None + transformer_decoder_object_queries: torch.FloatTensor = None + transformer_decoder_contrastive_queries: Optional[torch.FloatTensor] = None + transformer_decoder_mask_predictions: torch.FloatTensor = None + transformer_decoder_class_predictions: torch.FloatTensor = None + transformer_decoder_auxiliary_predictions: Optional[List[Dict[str, torch.FloatTensor]]] = None + text_queries: Optional[torch.FloatTensor] = None + task_token: torch.FloatTensor = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +# Modified from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrFrozenBatchNorm2d with DeformableDetr->OneFormerPixelDecoder +class OneFormerPixelDecoderFrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. + """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +# Modified from transformers.models.detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->OneFormerPixelDecoderEncoder +class OneFormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module): + """ + Multiscale deformable attention as proposed in Deformable DETR. + """ + + def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int): + super().__init__() + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}" + ) + dim_per_head = embed_dim // num_heads + # check if dim_per_head is power of 2 + if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): + warnings.warn( + "You'd better set embed_dim (d_model) in DeformableDetrMultiscaleDeformableAttention to make the" + " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" + " implementation." + ) + + self.im2col_step = 128 + + self.d_model = embed_dim + self.n_levels = n_levels + self.n_heads = num_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points) + self.value_proj = nn.Linear(embed_dim, embed_dim) + self.output_proj = nn.Linear(embed_dim, embed_dim) + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + output_attentions: bool = False, + ): + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: + raise ValueError( + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" + ) + + value = self.value_proj(encoder_hidden_states) + if attention_mask is not None: + # we invert the attention_mask + value = value.masked_fill(attention_mask[..., None], float(0)) + value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 + ) + attention_weights = self.attention_weights(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = nn.functional.softmax(attention_weights, -1).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points + ) + # batch_size, num_queries, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + # CPU + output = multiscale_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights) + output = self.output_proj(output) + + return output, attention_weights + + +class OneFormerPixelDecoderEncoderLayer(nn.Module): + def __init__(self, config: OneFormerConfig): + super().__init__() + self.embed_dim = config.conv_dim + self.self_attn = OneFormerPixelDecoderEncoderMultiscaleDeformableAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + n_levels=3, + n_points=4, + ) + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = nn.functional.relu + self.activation_dropout = config.dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_feedforward_dim) + self.fc2 = nn.Linear(config.encoder_feedforward_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + self.is_training = config.is_training + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor = None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Input to the layer. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Attention mask. + position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings, to be added to `hidden_states`. + reference_points (`torch.FloatTensor`, *optional*): + Reference points. + spatial_shapes (`torch.LongTensor`, *optional*): + Spatial shapes of the backbone feature maps. + level_start_index (`torch.LongTensor`, *optional*): + Level start index. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Apply Multi-scale Deformable Attention Module on the multi-scale feature maps. + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.is_training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.is_training) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.is_training) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if self.is_training: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Modified from from transformers.models.detr.modeling_deformable_detr.DeformableDetrEncoder with DeformableDetrEncoder->OneFormerPixelDecoderEncoderOnly +class OneFormerPixelDecoderEncoderOnly(nn.Module): + """ + Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a + [`OneFormerPixelDecoderEncoderLayer`]. + + The encoder updates the flattened multi-scale feature maps through multiple deformable attention layers. + + Args: + config: OneFormerConfig + """ + + def __init__(self, config: OneFormerConfig): + super().__init__() + + self.config = config + self.dropout = config.dropout + self.layers = nn.ModuleList([OneFormerPixelDecoderEncoderLayer(config) for _ in range(config.encoder_layers)]) + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + """ + Get reference points for each feature map. Used in decoder. + + Args: + spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of each feature map. + valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): + Valid ratios of each feature map. + device (`torch.device`): + Device on which to create the tensors. + Returns: + `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)` + """ + reference_points_list = [] + for lvl, (height, width) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device), + torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * width) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward( + self, + inputs_embeds=None, + attention_mask=None, + position_embeddings=None, + spatial_shapes=None, + level_start_index=None, + valid_ratios=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). + [What are attention masks?](../glossary#attention-mask) + position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of each feature map. + level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`): + Starting index of each feature map. + valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): + Ratio of valid area in each feature level. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeds + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for i, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Modified from from transformers.models.mask2former.modeling_mask2former.Mask2FormerPixelDecoder with Mask2->One +class OneFormerPixelDecoder(nn.Module): + def __init__(self, config: OneFormerConfig, feature_channels): + super().__init__() + + self.config = config + + # positional encoding + self.position_embedding = OneFormerSinePositionEmbedding(num_pos_feats=config.conv_dim // 2, normalize=True) + self.num_feature_levels = 3 + transformer_in_channels = feature_channels[-self.num_feature_levels :] + self.transformer_feature_strides = config.strides[-self.num_feature_levels :] + self.feature_channels = feature_channels + self.level_embed = nn.Parameter(torch.Tensor(self.num_feature_levels, config.conv_dim)) + + # Create input projection layers + if self.num_feature_levels > 1: + input_projections_list = [] + for in_channels in transformer_in_channels[::-1]: + input_projections_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.conv_dim, kernel_size=1), + nn.GroupNorm(32, config.conv_dim), + ) + ) + self.input_projections = nn.ModuleList(input_projections_list) + else: + self.input_projections = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(transformer_in_channels[-1], config.conv_dim, kernel_size=1), + nn.GroupNorm(32, config.conv_dim), + ) + ] + ) + + self.encoder = OneFormerPixelDecoderEncoderOnly(config) + + self.mask_projection = nn.Conv2d( + config.conv_dim, + config.mask_dim, + kernel_size=1, + stride=1, + padding=0, + ) + + self.common_stride = config.common_stride + + # extra fpn levels + stride = min(self.transformer_feature_strides) + self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride)) + + lateral_convs = [] + output_convs = [] + + for idx, in_channels in enumerate(self.feature_channels[: self.num_fpn_levels]): + lateral_conv = nn.Sequential( + nn.Conv2d( + in_channels, + config.conv_dim, + kernel_size=1, + bias=False, + ), + nn.GroupNorm(32, config.conv_dim), + ) + output_conv = nn.Sequential( + nn.Conv2d( + config.conv_dim, + config.conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + nn.GroupNorm(32, config.conv_dim), + nn.ReLU(), + ) + self.add_module("adapter_{}".format(idx + 1), lateral_conv) + self.add_module("layer_{}".format(idx + 1), output_conv) + + lateral_convs.append(lateral_conv) + output_convs.append(output_conv) + # Place convs into top-down order (from low to high resolution) + # to make the top-down computation in forward clearer. + self.lateral_convs = lateral_convs[::-1] + self.output_convs = output_convs[::-1] + + def get_valid_ratio(self, mask): + """Get the valid ratio of all feature maps.""" + + _, height, width = mask.shape + valid_height = torch.sum(~mask[:, :, 0], 1) + valid_width = torch.sum(~mask[:, 0, :], 1) + valid_ratio_heigth = valid_height.float() / height + valid_ratio_width = valid_width.float() / width + valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1) + return valid_ratio + + def forward( + self, + features, + encoder_outputs=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) + sources = [] + position_embeddings_list = [] + for level, source in enumerate(features[::-1][: self.num_feature_levels]): + feats = source.float() + sources.append(self.input_projections[level](feats)) + position_embeddings_list.append(self.position_embedding(feats)) + + masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in sources] + + # Prepare encoder inputs (by flattening) + source_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for level, (source, mask, pos_embed) in enumerate(zip(sources, masks, position_embeddings_list)): + batch_size, num_channels, height, width = source.shape + spatial_shape = (height, width) + spatial_shapes.append(spatial_shape) + source = source.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embed[level].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + source_flatten.append(source) + mask_flatten.append(mask) + source_flatten = torch.cat(source_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + valid_ratios = valid_ratios.float() + + # Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder + # Also provide spatial_shapes, level_start_index and valid_ratios + if encoder_outputs is None: + encoder_outputs = self.encoder( + inputs_embeds=source_flatten, + attention_mask=mask_flatten, + position_embeddings=lvl_pos_embed_flatten, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + y = encoder_outputs.last_hidden_state + bs = y.shape[0] + + split_size_or_sections = [None] * self.num_feature_levels + for i in range(self.num_feature_levels): + if i < self.num_feature_levels - 1: + split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i] + else: + split_size_or_sections[i] = y.shape[1] - level_start_index[i] + y = torch.split(y, split_size_or_sections, dim=1) + + out = [] + multi_scale_features = [] + num_cur_levels = 0 + for i, z in enumerate(y): + out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1])) + + # append `out` with extra FPN levels + # Reverse feature maps into top-down order (from low to high resolution) + for idx, feats in enumerate(features[: self.num_fpn_levels][::-1]): + feats = feats.float() + lateral_conv = self.lateral_convs[idx] + output_conv = self.output_convs[idx] + cur_fpn = lateral_conv(feats) + # Following FPN implementation, we use nearest upsampling here + y = cur_fpn + nn.functional.interpolate( + out[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False + ) + y = output_conv(y) + out.append(y) + + for o in out: + if num_cur_levels < self.num_feature_levels: + multi_scale_features.append(o) + num_cur_levels += 1 + + return OneFormerPixelDecoderOutput( + mask_features=self.mask_projection(out[-1]), + multi_scale_features=multi_scale_features, + attentions=encoder_outputs.attentions, + ) + + +# Modified from from transformers.models.mask2former.modeling_mask2former.Mask2FormerPixelLevelModule with Mask2->One +class OneFormerPixelLevelModule(nn.Module): + def __init__(self, config: OneFormerConfig): + """ + Pixel Level Module proposed in [Masked-attention Mask Transformer for Universal Image + Segmentation](https://arxiv.org/abs/2112.01527). It runs the input image through a backbone and a pixel + decoder, generating multi-scale feature maps and pixel embeddings. + + Args: + config ([`OneFormerConfig`]): + The configuration used to instantiate this model. + """ + super().__init__() + backbone_config = config.backbone_config + self.encoder = AutoBackbone.from_config(backbone_config) + self.decoder = OneFormerPixelDecoder(config, feature_channels=self.encoder.channels) + + def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> OneFormerPixelLevelModuleOutput: + features: List[Tensor] = self.encoder(pixel_values).feature_maps + decoder_output: OneFormerPixelDecoderOutput = self.decoder(features, output_hidden_states=output_hidden_states) + return OneFormerPixelLevelModuleOutput( + encoder_features=tuple(features), + decoder_features=decoder_output.multi_scale_features, + decoder_last_feature=decoder_output.mask_features, + ) + + +# Modified from transformers.models.detr.modeling_detr.DetrAttention with Detr->OneFormer +class OneFormerAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Here, we add position embeddings to the queries and + keys (as explained in the DETR paper). + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + if self.head_dim * num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + key_value_states: Optional[torch.Tensor] = None, + key_value_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + hidden_states = hidden_states.permute(1, 0, 2) if hidden_states is not None else None + position_embeddings = position_embeddings.permute(1, 0, 2) if position_embeddings is not None else None + key_value_states = key_value_states.permute(1, 0, 2) if key_value_states is not None else None + key_value_position_embeddings = ( + key_value_position_embeddings.permute(1, 0, 2) if key_value_position_embeddings is not None else None + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size, target_len, embed_dim = hidden_states.size() + + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states_original = hidden_states + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + # add key-value position embeddings to the key value states + if key_value_position_embeddings is not None: + key_value_states_original = key_value_states + key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, batch_size) + value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, batch_size) + value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size) + + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + source_len = key_states.size(1) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention mask should be of size {(target_len, batch_size * self.num_heads, source_len)}, but is" + f" {attention_mask.size()}" + ) + attn_weights += attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, target_len, embed_dim) + + attn_output = self.out_proj(attn_output).permute(1, 0, 2) + + return attn_output, attn_weights_reshaped + + +class OneFormerTransformerDecoderSelfAttentionLayer(nn.Module): + def __init__(self, embed_dim, num_heads, dropout=0.0, activation="relu", normalize_before=False): + super().__init__() + self.self_attn = OneFormerAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, is_decoder=True) + + self.norm = nn.LayerNorm(embed_dim) + self.dropout = nn.Dropout(dropout) + + self.activation = ACT2FN[activation] + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + output, + output_mask: Optional[Tensor] = None, + output_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output2, attention_weights = self.self_attn( + hidden_states=output, position_embeddings=query_pos, attention_mask=output_mask, output_attentions=True + ) + output = output + self.dropout(output2) + output = self.norm(output) + + return output, attention_weights + + def forward_pre( + self, + output, + output_mask: Optional[Tensor] = None, + output_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output2 = self.norm(output) + output2, attention_weights = self.self_attn( + hidden_states=output2, position_embeddings=query_pos, attention_mask=output_mask, output_attentions=True + ) + output = output + self.dropout(output2) + + return output, attention_weights + + def forward( + self, + output, + output_mask: Optional[Tensor] = None, + output_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(output, output_mask, output_key_padding_mask, query_pos) + return self.forward_post(output, output_mask, output_key_padding_mask, query_pos) + + +class OneFormerTransformerDecoderCrossAttentionLayer(nn.Module): + def __init__(self, embed_dim, num_heads, dropout=0.0, activation="relu", normalize_before=False): + super().__init__() + self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout) + + self.norm = nn.LayerNorm(embed_dim) + self.dropout = nn.Dropout(dropout) + + self.activation = ACT2FN[activation] + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + output, + memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output2, attention_weights = self.multihead_attn( + query=self.with_pos_embed(output, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + ) + output = output + self.dropout(output2) + output = self.norm(output) + + return output, attention_weights + + def forward_pre( + self, + output, + memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output2 = self.norm(output) + output2, attention_weights = self.multihead_attn( + query=self.with_pos_embed(output2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + ) + output = output + self.dropout(output2) + + return output, attention_weights + + def forward( + self, + output, + memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(output, memory, memory_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(output, memory, memory_mask, memory_key_padding_mask, pos, query_pos) + + +class OneFormerTransformerDecoderFFNLayer(nn.Module): + def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, activation="relu", normalize_before=False): + super().__init__() + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm = nn.LayerNorm(d_model) + + self.activation = ACT2FN[activation] + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, output): + output2 = self.linear2(self.dropout(self.activation(self.linear1(output)))) + output = output + self.dropout(output2) + output = self.norm(output) + return output + + def forward_pre(self, output): + output2 = self.norm(output) + output2 = self.linear2(self.dropout(self.activation(self.linear1(output2)))) + output = output + self.dropout(output2) + return output + + def forward(self, output): + if self.normalize_before: + return self.forward_pre(output) + return self.forward_post(output) + + +class OneFormerMLPPredictionHead(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3): + """ + A classic Multi Layer Perceptron (MLP). + + Args: + input_dim (`int`): + The input dimensions. + hidden_dim (`int`): + The hidden dimensions. + output_dim (`int`): + The output dimensions. + num_layers (int, *optional*, defaults to 3): + The number of layers. + """ + super().__init__() + in_dims = [input_dim] + [hidden_dim] * (num_layers - 1) + out_dims = [hidden_dim] * (num_layers - 1) + [output_dim] + + layers = [] + for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)): + layers.append( + PredictionBlock(in_dim, out_dim, activation=nn.ReLU() if i < num_layers - 1 else nn.Identity()) + ) + + self.layers = nn.Sequential(*layers) + + def forward(self, input: Tensor) -> Tensor: + return self.layers(input) + + +# refactored from original implementation +class OneFormerTransformerDecoderLayer(nn.Module): + def __init__(self, config: OneFormerConfig): + super().__init__() + self.embed_dim = config.hidden_dim + self.num_feature_levels = 3 + + self.cross_attn = OneFormerTransformerDecoderCrossAttentionLayer( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=0.0, + normalize_before=config.pre_norm, + ) + + self.self_attn = OneFormerTransformerDecoderSelfAttentionLayer( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=0.0, + normalize_before=config.pre_norm, + ) + + self.ffn = OneFormerTransformerDecoderFFNLayer( + d_model=self.embed_dim, + dim_feedforward=config.dim_feedforward, + dropout=0.0, + normalize_before=config.pre_norm, + ) + + def forward( + self, + index: int, + output: torch.Tensor, + multi_stage_features: List[torch.Tensor], + multi_stage_positional_embeddings: List[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + query_embeddings: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ): + """ + Args: + index (`int`): index of the layer in the Transformer decoder. + output (`torch.FloatTensor`): the object queries of shape `(N, batch, hidden_dim)` + multi_stage_features (`List[torch.Tensor]`): the multi-scale features from the pixel decoder. + multi_stage_positional_embeddings (`List[torch.Tensor]`): + positional embeddings for the multi_stage_features + attention_mask (`torch.FloatTensor`): attention mask for the masked cross attention layer + query_embeddings (`torch.FloatTensor`, *optional*): + position embeddings that are added to the queries and keys in the self-attention layer. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + + level_index = index % self.num_feature_levels + attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False + + # Masked Cross Attention + output, cross_attn_weights = self.cross_attn( + output, + multi_stage_features[level_index], + memory_mask=attention_mask, + memory_key_padding_mask=None, # here we do not apply masking on padded region + pos=multi_stage_positional_embeddings[level_index], + query_pos=query_embeddings, + ) + + # Self Attention + output, self_attn_weights = self.self_attn( + output, + output_mask=None, + output_key_padding_mask=None, + query_pos=query_embeddings, + ) + + # Fully Connected + output = self.ffn(output) + + outputs = (output,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +class OneFormerTransformerDecoderQueryTransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward( + self, + output, + memory, + output_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + output_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + intermediate = [] + + for layer in self.layers: + output = layer( + output, + memory, + output_mask=output_mask, + memory_mask=memory_mask, + output_key_padding_mask=output_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos, + ) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class OneFormerTransformerDecoderQueryTransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = ACT2FN[activation] + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + output, + memory, + output_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + output_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(output, query_pos) + output2 = self.self_attn(q, k, value=output, attn_mask=output_mask, key_padding_mask=output_key_padding_mask) + output2 = output2[0] + output = output + self.dropout1(output2) + output = self.norm1(output) + output2 = self.multihead_attn( + query=self.with_pos_embed(output, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + ) + output2 = output2[0] + output = output + self.dropout2(output2) + output = self.norm2(output) + output2 = self.linear2(self.dropout(self.activation(self.linear1(output)))) + output = output + self.dropout3(output2) + output = self.norm3(output) + return output + + def forward_pre( + self, + output, + memory, + output_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + output_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output2 = self.norm1(output) + q = k = self.with_pos_embed(output2, query_pos) + output2 = self.self_attn(q, k, value=output2, attn_mask=output_mask, key_padding_mask=output_key_padding_mask) + output2 = output2[0] + output = output + self.dropout1(output2) + output2 = self.norm2(output) + output2 = self.multihead_attn( + query=self.with_pos_embed(output2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + ) + output2 = output2[0] + output = output + self.dropout2(output2) + output2 = self.norm3(output) + output2 = self.linear2(self.dropout(self.activation(self.linear1(output2)))) + output = output + self.dropout3(output2) + return output + + def forward( + self, + output, + memory, + output_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + output_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre( + output, + memory, + output_mask, + memory_mask, + output_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + return self.forward_post( + output, + memory, + output_mask, + memory_mask, + output_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + + +class OneFormerTransformerDecoderQueryTransformer(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + ): + super().__init__() + + decoder_layer = OneFormerTransformerDecoderQueryTransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = OneFormerTransformerDecoderQueryTransformerDecoder( + decoder_layer, + num_decoder_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + + self.d_model = d_model + self.nhead = nhead + + def forward(self, src, mask, query_embed, pos_embed, task_token=None): + batch_size = src.shape[0] + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, batch_size, 1) + if mask is not None: + mask = mask.flatten(1) + + if task_token is None: + queries = torch.zeros_like(query_embed) + else: + queries = task_token.repeat(query_embed.shape[0], 1, 1) + + queries = self.decoder(queries, src, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed) + return queries.transpose(1, 2) + + +class OneFormerTransformerDecoder(nn.Module): + """ + Transformer decoder + """ + + def __init__(self, in_channels: int, config: OneFormerConfig): + super().__init__() + self.config = config + + self.dropout = config.dropout + self.num_heads = config.num_attention_heads + self.is_training = config.is_training + self.use_task_norm = config.use_task_norm + self.use_auxiliary_loss = config.use_auxiliary_loss + + self.query_transformer = OneFormerTransformerDecoderQueryTransformer( + d_model=config.hidden_dim, + dropout=config.dropout, + nhead=config.num_attention_heads, + dim_feedforward=config.dim_feedforward, + num_decoder_layers=config.query_dec_layers, + normalize_before=config.pre_norm, + return_intermediate_dec=False, + ) + + self.decoder_norm = nn.LayerNorm(config.hidden_dim) + + self.num_feature_levels = 3 + + self.layers = nn.ModuleList( + [OneFormerTransformerDecoderLayer(config) for _ in range(config.decoder_layers - 1)] + ) + + self.query_input_projection = nn.Conv2d(in_channels, config.hidden_dim, kernel_size=1) + + self.class_embed = nn.Linear(config.hidden_dim, config.num_labels + 1) + self.mask_embed = OneFormerMLPPredictionHead( + config.hidden_dim, + config.hidden_dim, + config.mask_dim, + 3, + ) + + def forward( + self, + task_token=None, + multi_stage_features=None, + multi_stage_positional_embeddings=None, + mask_features=None, + query_features=None, + query_embeddings=None, + query_embedder=None, + size_list=None, + output_attentions=None, + ): + if self.use_task_norm: + task_token = self.decoder_norm(task_token) + + object_queries = self.query_transformer( + query_features, + None, + query_embedder.weight[:-1], + self.query_input_projection(mask_features), + task_token if self.use_task_norm else None, + ) + + object_queries = object_queries[0].permute(1, 0, 2) + + queries = torch.cat([object_queries, task_token], dim=0) + + output = queries.clone() + + intermediate_class_predictions = [] + intermediate_mask_predictions = [] + + # prediction heads on learnable query features + outputs_class, outputs_mask, attention_mask = self.forward_prediction_heads( + output, mask_features, attention_mask_target_size=size_list[0] + ) + intermediate_class_predictions.append(outputs_class) + intermediate_mask_predictions.append(outputs_mask) + + attentions = () + + for index, layer in enumerate(self.layers): + layer_outputs = layer( + index=index, + output=output, + multi_stage_features=multi_stage_features, + multi_stage_positional_embeddings=multi_stage_positional_embeddings, + attention_mask=attention_mask, + query_embeddings=query_embeddings, + output_attentions=output_attentions, + ) + + output = layer_outputs[0] + attentions += (layer_outputs[1:],) + + outputs_class, outputs_mask, attention_mask = self.forward_prediction_heads( + output, mask_features, attention_mask_target_size=size_list[(index + 1) % self.num_feature_levels] + ) + intermediate_class_predictions.append(outputs_class) + intermediate_mask_predictions.append(outputs_mask) + + if not len(intermediate_mask_predictions) == len(self.layers) + 1: + raise ValueError( + "Intermediate predictions in the transformer decoder must have the same number of elements as number" + " of layers" + ) + + object_queries = layer_outputs[0].permute(1, 0, 2) + + contrastive_logits = queries.permute(1, 0, 2) + + return OneFormerTransformerDecoderOutput( + object_queries=object_queries, + contrastive_logits=contrastive_logits, + prediction_masks=intermediate_mask_predictions[-1], + prediction_class=intermediate_class_predictions[-1], + auxiliary_predictions=self._get_aux_predictions( + intermediate_class_predictions, intermediate_mask_predictions + ) + if self.use_auxiliary_loss + else None, + attentions=attentions, + ) + + def forward_prediction_heads(self, output, mask_features, attention_mask_target_size): + decoder_output = self.decoder_norm(output) + decoder_output = decoder_output.transpose(0, 1) + outputs_class = self.class_embed(decoder_output) + mask_embed = self.mask_embed(decoder_output) + outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) + + attention_mask = nn.functional.interpolate( + outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False + ) + + # must use bool type + # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged. + attention_mask = ( + attention_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5 + ).bool() + attention_mask = attention_mask.detach() + + return outputs_class, outputs_mask, attention_mask + + @torch.jit.unused + def _get_aux_predictions(self, outputs_class, outputs_seg_masks): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + aux_list = [ + {"class_queries_logits": a, "masks_queries_logits": b} + for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) + ] + return tuple(aux_list) + + +class OneFormerTransformerModule(nn.Module): + """ + The OneFormer's transformer module. + """ + + def __init__(self, in_features: int, config: OneFormerConfig): + super().__init__() + hidden_dim = config.hidden_dim + self.num_feature_levels = 3 + self.position_embedder = OneFormerSinePositionEmbedding(num_pos_feats=hidden_dim // 2, normalize=True) + self.queries_embedder = nn.Embedding(config.num_queries, hidden_dim) + self.input_projections = [] + + for _ in range(self.num_feature_levels): + if in_features != hidden_dim or config.enforce_input_proj: + self.input_projections.append(nn.Conv2d(in_features, hidden_dim, kernel_size=1)) + else: + self.input_projections.append(nn.Sequential()) + + self.decoder = OneFormerTransformerDecoder(in_channels=in_features, config=config) + self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim) + + def forward( + self, + multi_scale_features: List[Tensor], + mask_features: Tensor, + task_token: Tensor, + output_attentions: bool = False, + ) -> OneFormerTransformerDecoderOutput: + if not len(multi_scale_features) == self.num_feature_levels: + raise ValueError( + f"Number of elements in multi_scale_features ({len(multi_scale_features)}) and num_feature_levels" + f" ({self.num_feature_levels}) do not match!" + ) + multi_stage_features = [] + multi_stage_positional_embeddings = [] + size_list = [] + + for i in range(self.num_feature_levels): + size_list.append(multi_scale_features[i].shape[-2:]) + multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2)) + multi_stage_features.append( + self.input_projections[i](multi_scale_features[i]).flatten(2) + + self.level_embed.weight[i][None, :, None] + ) + + # flatten NxCxHxW to HWxNxC + multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1) + multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1) + + _, batch_size, _ = multi_stage_features[0].shape + + # QxNxC + query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1) + task_token = task_token.unsqueeze(0) + + query_features = self.position_embedder(mask_features, None) + + return self.decoder( + task_token=task_token, + multi_stage_features=multi_stage_features, + multi_stage_positional_embeddings=multi_stage_positional_embeddings, + mask_features=mask_features, + query_features=query_features, + query_embeddings=query_embeddings, + query_embedder=self.queries_embedder, + size_list=size_list, + output_attentions=output_attentions, + ) + + +# Copied from transformers.models.maskformer.modeling_maskformer.MaskFormerSinePositionEmbedding with Mask->One +class OneFormerSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. + """ + + def __init__( + self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None + ): + super().__init__() + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + self.scale = 2 * math.pi if scale is None else scale + + def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +# Copied from transformers.models.maskformer.modeling_maskformer.PredictionBlock +class PredictionBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None: + super().__init__() + self.layers = [nn.Linear(in_dim, out_dim), activation] + # Maintain submodule indexing as if part of a Sequential block + for i, layer in enumerate(self.layers): + self.add_module(str(i), layer) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +class OneFormerTextMapperAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, k, v): + batch_size, q_sequence_length, num_channels = q.shape + if not k.shape == v.shape: + raise ValueError(f"keys ({list(k.shape)}) and values ({list(v.shape)}) have different shapes!") + batch_size, k_sequence_length, num_channels = k.shape + q = self.q_proj(q).reshape(batch_size, q_sequence_length, self.num_heads, num_channels // self.num_heads) + k = self.k_proj(k).reshape(batch_size, k_sequence_length, self.num_heads, num_channels // self.num_heads) + v = self.v_proj(v).reshape(batch_size, k_sequence_length, self.num_heads, num_channels // self.num_heads) + + attn = torch.einsum("bnkc,bmkc->bknm", q, k) * self.scale + + attn = attn.softmax(dim=-1) + + output = torch.einsum("bknm,bmkc->bnkc", attn, v).reshape(batch_size, q_sequence_length, num_channels) + + output = self.proj(output) + output = self.proj_drop(output) + return output + + +class OneFormerTextTransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dropout=0.1, + ): + super().__init__() + self.self_attn = OneFormerTextMapperAttention(d_model, nhead, proj_drop=dropout) + self.cross_attn = OneFormerTextMapperAttention(d_model, nhead, proj_drop=dropout) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.mlp = nn.Sequential( + nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model * 4, d_model) + ) + + def forward(self, hidden_state, mem): + q = k = v = self.norm1(hidden_state) + hidden_state = hidden_state + self.self_attn(q, k, v) + q = self.norm2(hidden_state) + hidden_state = hidden_state + self.cross_attn(q, mem, mem) + hidden_state = hidden_state + self.dropout(self.mlp(self.norm3(hidden_state))) + return hidden_state + + +class OneFormerTextContextDecoder(nn.Module): + def __init__( + self, transformer_width=256, transformer_heads=4, transformer_layers=6, visual_dim=1024, dropout=0.1, **kwargs + ): + super().__init__() + + self.memory_proj = nn.Sequential( + nn.LayerNorm(visual_dim), + nn.Linear(visual_dim, transformer_width), + nn.LayerNorm(transformer_width), + ) + + self.text_proj = nn.Sequential( + nn.LayerNorm(visual_dim), + nn.Linear(visual_dim, transformer_width), + ) + + self.decoder = nn.ModuleList( + [ + OneFormerTextTransformerDecoderLayer(transformer_width, transformer_heads, dropout) + for _ in range(transformer_layers) + ] + ) + + self.out_proj = nn.Sequential(nn.LayerNorm(transformer_width), nn.Linear(transformer_width, visual_dim)) + + def forward(self, text, visual): + visual = self.memory_proj(visual) + hidden_state = self.text_proj(text) + + for layer in self.decoder: + hidden_state = layer(hidden_state, visual) + + return self.out_proj(hidden_state) + + +class OneFormerTextMLP(nn.Module): + def __init__( + self, + hidden_size: Optional[int] = None, + intermediate_size: Optional[int] = None, + output_size: Optional[int] = None, + ): + super().__init__() + self.activation_fn = ACT2FN["quick_gelu"] + hidden_size = hidden_size + intermediate_size = intermediate_size + output_size = output_size + self.fc1 = nn.Linear(hidden_size, intermediate_size) + self.fc2 = nn.Linear(intermediate_size, output_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class OneFormerTextTransformerLayer(nn.Module): + def __init__(self, width: int, heads: int, attn_mask: torch.Tensor): + super().__init__() + self.self_attn = nn.MultiheadAttention(width, heads) + self.layer_norm1 = nn.LayerNorm(width) + self.mlp = OneFormerTextMLP(width, width * 4, width) + self.layer_norm2 = nn.LayerNorm(width) + self.attn_mask = attn_mask + + def forward( + self, + hidden_states: torch.Tensor, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states, + hidden_states, + hidden_states, + need_weights=False, + key_padding_mask=key_padding_mask, + )[0] + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class OneFormerTextTransformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_checkpoint=False): + super().__init__() + self.width = width + self.num_layers = layers + self.layers = nn.Sequential(*[OneFormerTextTransformerLayer(width, heads, attn_mask) for _ in range(layers)]) + self.use_checkpoint = use_checkpoint + + def forward(self, hidden_states: torch.Tensor): + for layer in self.layers: + if self.use_checkpoint: + hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states) + else: + hidden_states = layer(hidden_states) + return hidden_states + + +class OneFormerTextEncoder(nn.Module): + def __init__( + self, + context_length: int, + width: int, + layers: int, + vocab_size, + use_checkpoint=False, + ): + super().__init__() + heads = width // 64 + self.context_length = context_length + self.width = width + self.transformer = OneFormerTextTransformer( + width=width, + layers=layers, + heads=heads, + attn_mask=self.build_attention_mask(), + use_checkpoint=use_checkpoint, + ) + + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width)) + self.ln_final = nn.LayerNorm(width) + self.token_embedding = nn.Embedding(vocab_size, width) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, text): + hidden_state = self.token_embedding(text) + hidden_state = hidden_state + self.positional_embedding + hidden_state = hidden_state.permute(1, 0, 2) + hidden_state = self.transformer(hidden_state) + hidden_state = hidden_state.permute(1, 0, 2) + hidden_state = self.ln_final(hidden_state) + hidden_state = hidden_state[torch.arange(hidden_state.shape[0]), text.argmax(dim=-1)] + + return hidden_state + + +class OneFormerTextMapper(nn.Module): + def __init__(self, config: OneFormerConfig): + super().__init__() + self.text_encoder = OneFormerTextEncoder( + context_length=config.text_encoder_context_length, + width=config.text_encoder_width, + layers=config.text_encoder_num_layers, + vocab_size=config.text_encoder_vocab_size, + ) + + self.text_projector = OneFormerMLPPredictionHead( + config.text_encoder_width, + config.hidden_dim, + config.hidden_dim, + config.text_encoder_proj_layers, + ) + if config.text_encoder_n_ctx > 0: + self.prompt_ctx = nn.Embedding( + config.text_encoder_n_ctx, + config.text_encoder_width, + ) + else: + self.prompt_ctx = None + + def forward( + self, + inputs: Tensor, + ) -> Tensor: + text_queries = self.encode_text(inputs) + + return text_queries + + def encode_text(self, text): + if text.ndim is None: + raise ValueError("text must not be NoneType") + if text.ndim not in [2, 3]: + raise ValueError("Number of dimensions in text must be 2 or 3") + squeeze_dim = False + num_text = 1 + if text.ndim == 3: + num_text = text.shape[1] + batch_size, num_text, hidden_dim = text.shape + text = text.reshape(batch_size * num_text, hidden_dim) + squeeze_dim = True + + # [batch_size, num_channels] + encoded_text = self.text_encoder(text) + + text_queries = self.text_projector(encoded_text) + + if squeeze_dim: + _, hidden_dim = text_queries.shape + text_queries = text_queries.reshape(batch_size, num_text, hidden_dim) + if self.prompt_ctx is not None: + text_queries_ctx = self.prompt_ctx.weight.unsqueeze(0).repeat(text_queries.shape[0], 1, 1) + text_queries = torch.cat([text_queries, text_queries_ctx], dim=1) + + return text_queries + + +class OneFormerTaskModel(nn.Module): + def __init__(self, config: OneFormerConfig): + super().__init__() + self.task_mlp = OneFormerMLPPredictionHead( + config.task_seq_len, + config.hidden_dim, + config.hidden_dim, + 2, + ) + + def forward(self, inputs: Tensor) -> Tensor: + task_tokens = self.task_mlp(inputs.float()) + return task_tokens + + +ONEFORMER_START_DOCSTRING = r""" + This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use it as a + regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. + + Parameters: + config ([`OneFormerConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ONEFORMER_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`OneFormerProcessor`]. See + [`OneFormerProcessor.__call__`] for details. + task_inputs (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Task inputs. Task inputs can be obtained using [`OneFormerImageProcessor`]. See + [`OneFormerProcessor.__call__`] for details. + pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + [What are attention masks?](../glossary#attention-mask) + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of Detr's decoder attention layers. + return_dict (`bool`, *optional*): + Whether or not to return a [`~OneFormerModelOutput`] instead of a plain tuple. +""" + + +class OneFormerPreTrainedModel(PreTrainedModel): + config_class = OneFormerConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + + def _init_weights(self, module: nn.Module): + xavier_std = self.config.init_xavier_std + std = self.config.init_std + if isinstance(module, OneFormerTransformerModule): + if module.input_projections is not None: + for input_projection in module.input_projections: + if not isinstance(input_projection, nn.Sequential): + nn.init.xavier_uniform_(input_projection.weight, gain=xavier_std) + nn.init.constant_(input_projection.bias, 0) + elif isinstance(module, OneFormerTransformerDecoder): + nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std) + nn.init.constant_(module.query_input_projection.bias, 0) + elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention): + nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(module.n_heads, dtype=torch.float32) * (2.0 * math.pi / module.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(module.n_heads, 1, 1, 2) + .repeat(1, module.n_levels, module.n_points, 1) + ) + for i in range(module.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + nn.init.constant_(module.attention_weights.weight.data, 0.0) + nn.init.constant_(module.attention_weights.bias.data, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight.data) + nn.init.constant_(module.value_proj.bias.data, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight.data) + nn.init.constant_(module.output_proj.bias.data, 0.0) + elif isinstance(module, OneFormerPixelDecoderEncoderOnly): + for p in module.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + elif isinstance(module, OneFormerPixelDecoder): + for p in module.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + nn.init.normal_(module.level_embed, std=0) + elif isinstance(module, OneFormerTransformerDecoderSelfAttentionLayer): + for p in module.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p, gain=xavier_std) + elif isinstance(module, OneFormerTransformerDecoderCrossAttentionLayer): + for p in module.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p, gain=xavier_std) + elif isinstance(module, OneFormerTransformerDecoderFFNLayer): + for p in module.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p, gain=xavier_std) + elif isinstance(module, OneFormerTransformerDecoderQueryTransformer): + for p in module.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p, gain=xavier_std) + elif isinstance(module, OneFormerPixelLevelModule): + for submodule in module.modules(): + if isinstance(submodule, (nn.Conv2d, nn.Linear)): + submodule.weight.data.normal_(mean=0.0, std=std) + if submodule.bias is not None: + submodule.bias.data.zero_() + elif isinstance(module, OneFormerTextContextDecoder): + for submodule in module.modules(): + if isinstance(submodule, nn.Linear): + nn.init.trunc_normal_(submodule.weight, std=0.02) + if isinstance(submodule, nn.Linear) and submodule.bias is not None: + nn.init.constant_(submodule.bias, 0) + elif isinstance(submodule, nn.LayerNorm): + nn.init.constant_(submodule.bias, 0) + nn.init.constant_(submodule.weight, 1.0) + elif isinstance(module, OneFormerTextTransformer): + proj_std = (module.width**-0.5) * ((2 * module.num_layers) ** -0.5) + attn_std = module.width**-0.5 + fc_std = (2 * module.width) ** -0.5 + for layer in module.layers: + nn.init.normal_(layer.self_attn.in_proj_weight, std=attn_std) + nn.init.normal_(layer.self_attn.out_proj.weight, std=proj_std) + nn.init.normal_(layer.mlp.fc1.weight, std=fc_std) + nn.init.normal_(layer.mlp.fc2.weight, std=proj_std) + elif isinstance(module, OneFormerTextEncoder): + nn.init.normal_(module.token_embedding.weight, std=0.02) + nn.init.normal_(module.positional_embedding, std=0.01) + if hasattr(module, "reference_points"): + nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0) + nn.init.constant_(module.reference_points.bias.data, 0.0) + elif isinstance(module, OneFormerTaskModel): + for submodule in module.modules(): + if isinstance(module, OneFormerMLPPredictionHead): + for submodule in module.modules(): + if isinstance(submodule, nn.Linear): + nn.init.xavier_uniform_(submodule.weight, gain=xavier_std) + nn.init.constant_(submodule.bias, 0) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.MultiheadAttention): + module.in_proj_weight.data.normal_(mean=0.0, std=std) + module.in_proj_bias.data.zero_() + elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +@add_start_docstrings( + "The bare OneFormer Model outputting raw hidden-states without any specific head on top.", + ONEFORMER_START_DOCSTRING, +) +class OneFormerModel(OneFormerPreTrainedModel): + main_input_name = ["pixel_values", "task_inputs"] + + def __init__(self, config: OneFormerConfig): + super().__init__(config) + self.pixel_level_module = OneFormerPixelLevelModule(config) + self.transformer_module = OneFormerTransformerModule(in_features=config.conv_dim, config=config) + self.task_encoder = OneFormerTaskModel(config) + self.is_training = config.is_training + + if self.is_training: + self.text_mapper = OneFormerTextMapper(config) + else: + self.text_mapper = None + + self.post_init() + + @add_start_docstrings_to_model_forward(ONEFORMER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=OneFormerModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Tensor, + task_inputs: Tensor, + text_inputs: Optional[Tensor] = None, + pixel_mask: Optional[Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> OneFormerModelOutput: + r""" + Returns: + `OneFormerModelOutput` + Example: + + ```python + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from transformers import OneFormerProcessor, OneFormerModel + + >>> # download texting image + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # load processor for preprocessing the inputs + >>> processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny") + >>> model = OneFormerModel.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny") + >>> inputs = processor(image, ["semantic"], return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> mask_predictions = outputs.transformer_decoder_mask_predictions + >>> class_predictions = outputs.transformer_decoder_class_predictions + + >>> f"👉 Mask Predictions Shape: {list(mask_predictions.shape)}, Class Predictions Shape: {list(class_predictions.shape)}" + '👉 Mask Predictions Shape: [1, 150, 128, 176], Class Predictions Shape: [1, 150, 151]' + ```""" + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, _, height, width = pixel_values.shape + + if pixel_mask is None: + pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device) + + pixel_level_module_output = self.pixel_level_module(pixel_values, output_hidden_states) + + multi_scale_features = pixel_level_module_output.decoder_features + mask_features = pixel_level_module_output.decoder_last_feature + + task_token = self.task_encoder(task_inputs) + + if self.is_training: + text_queries = self.text_mapper(text_inputs) + else: + text_queries = None + + transformer_module_output = self.transformer_module( + multi_scale_features=multi_scale_features, + mask_features=mask_features, + task_token=task_token, + output_attentions=output_attentions, + ) + + queries = transformer_module_output.object_queries + + encoder_hidden_states = None + pixel_decoder_hidden_states = None + transformer_decoder_hidden_states = None + + if output_hidden_states: + encoder_hidden_states = pixel_level_module_output.encoder_features + pixel_decoder_hidden_states = (pixel_level_module_output.decoder_last_feature,) + for f in pixel_level_module_output.decoder_features: + pixel_decoder_hidden_states += (f,) + transformer_decoder_hidden_states = transformer_module_output.auxiliary_predictions + + output = OneFormerModelOutput( + encoder_hidden_states=encoder_hidden_states, + pixel_decoder_hidden_states=pixel_decoder_hidden_states, + transformer_decoder_hidden_states=transformer_decoder_hidden_states, + transformer_decoder_object_queries=queries, + transformer_decoder_contrastive_queries=transformer_module_output.contrastive_logits, + transformer_decoder_mask_predictions=transformer_module_output.prediction_masks, + transformer_decoder_class_predictions=transformer_module_output.prediction_class, + transformer_decoder_auxiliary_predictions=transformer_module_output.auxiliary_predictions, + text_queries=text_queries, + task_token=task_token, + attentions=transformer_module_output.attentions, + ) + + if not return_dict: + output = tuple(v for v in output.values()) + + return output + + +@add_start_docstrings( + "OneFormer Model for instance, semantic and panoptic image segmentation.", + ONEFORMER_START_DOCSTRING, +) +class OneFormerForUniversalSegmentation(OneFormerPreTrainedModel): + main_input_name = ["pixel_values", "task_inputs"] + + def __init__(self, config: OneFormerConfig): + super().__init__(config) + self.model = OneFormerModel(config) + + self.matcher = OneFormerHungarianMatcher( + cost_class=config.class_weight, + cost_dice=config.dice_weight, + cost_mask=config.mask_weight, + num_points=config.train_num_points, + ) + + self.weight_dict: Dict[str, float] = { + "loss_cross_entropy": config.class_weight, + "loss_mask": config.mask_weight, + "loss_dice": config.dice_weight, + "loss_contrastive": config.contrastive_weight, + } + + self.criterion = OneFormerLoss( + num_classes=config.num_labels, + matcher=self.matcher, + weight_dict=self.weight_dict, + eos_coef=config.no_object_weight, + num_points=config.train_num_points, + oversample_ratio=config.oversample_ratio, + importance_sample_ratio=config.importance_sample_ratio, + contrastive_temperature=config.contrastive_temperature, + ) + + self.post_init() + + def get_loss_dict( + self, + masks_queries_logits: Tensor, + class_queries_logits: Tensor, + contrastive_queries_logits: Tensor, + mask_labels: Tensor, + class_labels: Tensor, + text_queries: Tensor, + auxiliary_predictions: Dict[str, Tensor], + calculate_contrastive_loss: bool, + ) -> Dict[str, Tensor]: + loss_dict: Dict[str, Tensor] = self.criterion( + masks_queries_logits=masks_queries_logits, + class_queries_logits=class_queries_logits, + contrastive_queries_logits=contrastive_queries_logits, + mask_labels=mask_labels, + class_labels=class_labels, + text_queries=text_queries, + auxiliary_predictions=auxiliary_predictions, + calculate_contrastive_loss=calculate_contrastive_loss, + ) + + # weight each loss by `self.weight_dict[]` including auxiliary losses + for key, weight in self.weight_dict.items(): + for loss_key, loss in loss_dict.items(): + if key in loss_key: + loss *= weight + + return loss_dict + + def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor: + return sum(loss_dict.values()) + + @add_start_docstrings_to_model_forward(ONEFORMER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=OneFormerForUniversalSegmentationOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Tensor, + task_inputs: Tensor, + text_inputs: Optional[Tensor] = None, + mask_labels: Optional[List[Tensor]] = None, + class_labels: Optional[List[Tensor]] = None, + pixel_mask: Optional[Tensor] = None, + output_auxiliary_logits: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> OneFormerForUniversalSegmentationOutput: + r""" + text_inputs (`List[torch.Tensor]`, *optional*): + Tensor fof shape `(num_queries, sequence_length)` to be fed to a model + mask_labels (`List[torch.Tensor]`, *optional*): + List of mask labels of shape `(num_labels, height, width)` to be fed to a model + class_labels (`List[torch.LongTensor]`, *optional*): + list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the + labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`. + + Returns: + `OneFormerUniversalSegmentationOutput` + Example: + + Universal segmentation example: + + ```python + >>> from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation + >>> from PIL import Image + >>> import requests + >>> import torch + + >>> # load OneFormer fine-tuned on ADE20k for universal segmentation + >>> processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny") + >>> model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny") + + >>> url = ( + ... "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg" + ... ) + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # Semantic Segmentation + >>> inputs = processor(image, ["semantic"], return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + >>> # model predicts class_queries_logits of shape `(batch_size, num_queries)` + >>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)` + >>> class_queries_logits = outputs.class_queries_logits + >>> masks_queries_logits = outputs.masks_queries_logits + + >>> # you can pass them to feature_extractor for semantic postprocessing + >>> predicted_semantic_map = feature_extractor.post_process_semantic_segmentation( + ... outputs, target_sizes=[image.size[::-1]] + ... )[0] + >>> f"👉 Semantic Predictions Shape: {list(predicted_semantic_map.shape)}" + '👉 Semantic Predictions Shape: [512, 683]' + + >>> # Instance Segmentation + >>> inputs = processor(image, ["instance"], return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + >>> # model predicts class_queries_logits of shape `(batch_size, num_queries)` + >>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)` + >>> class_queries_logits = outputs.class_queries_logits + >>> masks_queries_logits = outputs.masks_queries_logits + + >>> # you can pass them to feature_extractor for instance postprocessing + >>> predicted_instance_map = feature_extractor.post_process_instance_segmentation( + ... outputs, target_sizes=[image.size[::-1]] + ... )[0]["segmentation"] + >>> f"👉 Instance Predictions Shape: {list(predicted_instance_map.shape)}" + '👉 Instance Predictions Shape: [512, 683]' + + >>> # Panoptic Segmentation + >>> inputs = processor(image, ["panoptic"], return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + >>> # model predicts class_queries_logits of shape `(batch_size, num_queries)` + >>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)` + >>> class_queries_logits = outputs.class_queries_logits + >>> masks_queries_logits = outputs.masks_queries_logits + + >>> # you can pass them to feature_extractor for panoptic postprocessing + >>> predicted_panoptic_map = feature_extractor.post_process_panoptic_segmentation( + ... outputs, target_sizes=[image.size[::-1]] + ... )[0]["segmentation"] + >>> f"👉 Panoptic Predictions Shape: {list(predicted_panoptic_map.shape)}" + '👉 Panoptic Predictions Shape: [512, 683]' + ``` + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + pixel_values=pixel_values, + task_inputs=task_inputs, + text_inputs=text_inputs, + pixel_mask=pixel_mask, + output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss, + output_attentions=output_attentions, + return_dict=True, + ) + + loss, loss_dict, auxiliary_predictions = None, None, None + + class_queries_logits = outputs.transformer_decoder_class_predictions + masks_queries_logits = outputs.transformer_decoder_mask_predictions + contrastive_queries_logits = outputs.transformer_decoder_contrastive_queries + auxiliary_predictions = outputs.transformer_decoder_auxiliary_predictions + text_queries = outputs.text_queries + + if mask_labels is not None and class_labels is not None: + loss_dict: Dict[str, Tensor] = self.get_loss_dict( + masks_queries_logits=masks_queries_logits, + class_queries_logits=class_queries_logits, + contrastive_queries_logits=contrastive_queries_logits, + mask_labels=mask_labels, + class_labels=class_labels, + text_queries=text_queries, + auxiliary_predictions=auxiliary_predictions, + calculate_contrastive_loss=self.config.contrastive_temperature is not None, + ) + loss = self.get_loss(loss_dict) + + output_auxiliary_logits = ( + self.config.output_auxiliary_logits if output_auxiliary_logits is None else output_auxiliary_logits + ) + if not output_auxiliary_logits: + auxiliary_predictions = None + + output = OneFormerForUniversalSegmentationOutput( + class_queries_logits=class_queries_logits, + masks_queries_logits=masks_queries_logits, + auxiliary_predictions=auxiliary_predictions, + loss=loss, + **outputs, + ) + + if not return_dict: + output = tuple(v for v in output.values()) + if loss is not None: + output = ((loss)) + output + return output diff --git a/src/transformers/models/oneformer/processing_oneformer.py b/src/transformers/models/oneformer/processing_oneformer.py new file mode 100644 index 000000000000..bc392a77c14d --- /dev/null +++ b/src/transformers/models/oneformer/processing_oneformer.py @@ -0,0 +1,205 @@ +# coding=utf-8 +# Copyright 2022 SHI Labs and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image/Text processor class for OneFormer +""" + +from typing import List + +from transformers.utils import is_torch_available + +from ...processing_utils import ProcessorMixin + + +if is_torch_available(): + import torch + + +class OneFormerProcessor(ProcessorMixin): + r""" + Constructs an OneFormer processor which wraps [`OneFormerImageProcessor`] and + [`CLIPTokenizer`]/[`CLIPTokenizerFast`] into a single processor that inherits both the image processor and + tokenizer functionalities. + + Args: + image_processor ([`OneFormerImageProcessor`]): + The image processor is a required input. + tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]): + The tokenizer is a required input. + max_seq_len (`int`, *optional*, defaults to 77)): + Sequence length for input text list. + task_seq_len (`int`, *optional*, defaults to 77): + Sequence length for input task token. + """ + attributes = ["image_processor", "tokenizer"] + image_processor_class = "OneFormerImageProcessor" + tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast") + + def __init__( + self, image_processor=None, tokenizer=None, max_seq_length: int = 77, task_seq_length: int = 77, **kwargs + ): + if image_processor is None: + raise ValueError("You need to specify an `image_processor`.") + if tokenizer is None: + raise ValueError("You need to specify a `tokenizer`.") + + self.max_seq_length = max_seq_length + self.task_seq_length = task_seq_length + + super().__init__(image_processor, tokenizer) + + def _preprocess_text(self, text_list=None, max_length=77): + if text_list is None: + raise ValueError("tokens cannot be None.") + + tokens = self.tokenizer(text_list, padding="max_length", max_length=max_length, truncation=True) + + attention_masks, input_ids = tokens["attention_mask"], tokens["input_ids"] + + token_inputs = [] + for attn_mask, input_id in zip(attention_masks, input_ids): + token = torch.tensor(attn_mask) * torch.tensor(input_id) + token_inputs.append(token.unsqueeze(0)) + + token_inputs = torch.cat(token_inputs, dim=0) + return token_inputs + + def __call__(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs): + """ + Main method to prepare for the model one or several task input(s) and image(s). This method forwards the + `task_inputs` and `kwargs` arguments to CLIPTokenizer's [`~CLIPTokenizer.__call__`] if `task_inputs` is not + `None` to encode. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to + OneFormerImageProcessor's [`~OneFormerImageProcessor.__call__`] if `images` is not `None`. Please refer to the + doctsring of the above two methods for more information. + + Args: + task_inputs (`str`, `List[str]`): + The sequence or batch of task_inputs sequences to be encoded. Each sequence can be a string or a list + of strings of the template "the task is {task}". + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, + `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + segmentation_maps (`ImageInput`, *optional*): + The corresponding semantic segmentation maps with the pixel-wise annotations. + + (`bool`, *optional*, defaults to `True`): + Whether or not to pad images up to the largest image in a batch and create a pixel mask. + + If left to the default, will return a pixel mask that is: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + - **task_inputs** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + + if task_inputs is None: + raise ValueError("You have to specify the task_input. Found None.") + elif images is None: + raise ValueError("You have to specify the image. Found None.") + + if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs): + raise ValueError("task_inputs must be semantic, instance, or panoptic.") + + encoded_inputs = self.image_processor(images, task_inputs, segmentation_maps, **kwargs) + + if isinstance(task_inputs, str): + task_inputs = [task_inputs] + + if isinstance(task_inputs, List) and all(isinstance(task_input, str) for task_input in task_inputs): + task_token_inputs = [] + for task in task_inputs: + task_input = f"the task is {task}" + task_token_inputs.append(task_input) + encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length) + else: + raise TypeError("Task Inputs should be a string or a list of strings.") + + if hasattr(encoded_inputs, "text_inputs"): + texts_list = encoded_inputs.text_inputs + + text_inputs = [] + for texts in texts_list: + text_input_list = self._preprocess_text(texts, max_length=self.max_seq_length) + text_inputs.append(text_input_list.unsqueeze(0)) + + encoded_inputs["text_inputs"] = torch.cat(text_inputs, dim=0) + + return encoded_inputs + + def encode_inputs(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs): + """ + This method forwards all its arguments to [`OneFormerImageProcessor.encode_inputs`] and then tokenizes the + task_inputs. Please refer to the docstring of this method for more information. + """ + + if task_inputs is None: + raise ValueError("You have to specify the task_input. Found None.") + elif images is None: + raise ValueError("You have to specify the image. Found None.") + + if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs): + raise ValueError("task_inputs must be semantic, instance, or panoptic.") + + encoded_inputs = self.image_processor.encode_inputs(images, task_inputs, segmentation_maps, **kwargs) + + if isinstance(task_inputs, str): + task_inputs = [task_inputs] + + if isinstance(task_inputs, List) and all(isinstance(task_input, str) for task_input in task_inputs): + task_token_inputs = [] + for task in task_inputs: + task_input = f"the task is {task}" + task_token_inputs.append(task_input) + encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length) + else: + raise TypeError("Task Inputs should be a string or a list of strings.") + + if hasattr(encoded_inputs, "text_inputs"): + texts_list = encoded_inputs.text_inputs + + text_inputs = [] + for texts in texts_list: + text_input_list = self._preprocess_text(texts, max_length=self.max_seq_length) + text_inputs.append(text_input_list.unsqueeze(0)) + + encoded_inputs["text_inputs"] = torch.cat(text_inputs, dim=0) + + return encoded_inputs + + def post_process_semantic_segmentation(self, *args, **kwargs): + """ + This method forwards all its arguments to [`OneFormerImageProcessor.post_process_semantic_segmentation`]. + Please refer to the docstring of this method for more information. + """ + return self.image_processor.post_process_semantic_segmentation(*args, **kwargs) + + def post_process_instance_segmentation(self, *args, **kwargs): + """ + This method forwards all its arguments to [`OneFormerImageProcessor.post_process_instance_segmentation`]. + Please refer to the docstring of this method for more information. + """ + return self.image_processor.post_process_instance_segmentation(*args, **kwargs) + + def post_process_panoptic_segmentation(self, *args, **kwargs): + """ + This method forwards all its arguments to [`OneFormerImageProcessor.post_process_panoptic_segmentation`]. + Please refer to the docstring of this method for more information. + """ + return self.image_processor.post_process_panoptic_segmentation(*args, **kwargs) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 1ae0f4f2f3a1..4b68242146fd 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4242,6 +4242,30 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +ONEFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class OneFormerForUniversalSegmentation(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class OneFormerModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class OneFormerPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 9237d637c3df..18b1f07ef264 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -318,6 +318,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class OneFormerImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class OwlViTFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/oneformer/__init__.py b/tests/models/oneformer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/oneformer/test_image_processing_oneformer.py b/tests/models/oneformer/test_image_processing_oneformer.py new file mode 100644 index 000000000000..f34ae080cf96 --- /dev/null +++ b/tests/models/oneformer/test_image_processing_oneformer.py @@ -0,0 +1,456 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import unittest + +import numpy as np + +from huggingface_hub import hf_hub_download +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + + if is_vision_available(): + from transformers import OneFormerImageProcessor + from transformers.models.oneformer.image_processing_oneformer import binary_mask_to_rle + from transformers.models.oneformer.modeling_oneformer import OneFormerForUniversalSegmentationOutput + +if is_vision_available(): + from PIL import Image + + +def prepare_metadata(class_info_file, repo_path="shi-labs/oneformer_demo"): + with open(hf_hub_download(repo_path, class_info_file, repo_type="dataset"), "r") as f: + class_info = json.load(f) + metadata = {} + class_names = [] + thing_ids = [] + for key, info in class_info.items(): + metadata[key] = info["name"] + class_names.append(info["name"]) + if info["isthing"]: + thing_ids.append(int(key)) + metadata["thing_ids"] = thing_ids + metadata["class_names"] = class_names + return metadata + + +class OneFormerImageProcessorTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + min_resolution=30, + max_resolution=400, + size=None, + do_resize=True, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + num_labels=10, + reduce_labels=False, + ignore_index=255, + repo_path="shi-labs/oneformer_demo", + class_info_file="ade20k_panoptic.json", + num_text=10, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = {"shortest_edge": 32, "longest_edge": 1333} if size is None else size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.class_info_file = class_info_file + self.metadata = prepare_metadata(class_info_file, repo_path) + self.num_text = num_text + self.repo_path = repo_path + + # for the post_process_functions + self.batch_size = 2 + self.num_queries = 10 + self.num_classes = 10 + self.height = 3 + self.width = 4 + self.num_labels = num_labels + self.reduce_labels = reduce_labels + self.ignore_index = ignore_index + + def prepare_feat_extract_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "num_labels": self.num_labels, + "reduce_labels": self.reduce_labels, + "ignore_index": self.ignore_index, + "class_info_file": self.class_info_file, + "metadata": self.metadata, + "num_text": self.num_text, + } + + def get_expected_values(self, image_inputs, batched=False): + """ + This function computes the expected height and width when providing images to OneFormerImageProcessor, + assuming do_resize is set to True with a scalar size. + """ + if not batched: + image = image_inputs[0] + if isinstance(image, Image.Image): + w, h = image.size + else: + h, w = image.shape[1], image.shape[2] + if w < h: + expected_height = int(self.size["shortest_edge"] * h / w) + expected_width = self.size["shortest_edge"] + elif w > h: + expected_height = self.size["shortest_edge"] + expected_width = int(self.size["shortest_edge"] * w / h) + else: + expected_height = self.size["shortest_edge"] + expected_width = self.size["shortest_edge"] + + else: + expected_values = [] + for image in image_inputs: + expected_height, expected_width = self.get_expected_values([image]) + expected_values.append((expected_height, expected_width)) + expected_height = max(expected_values, key=lambda item: item[0])[0] + expected_width = max(expected_values, key=lambda item: item[1])[1] + + return expected_height, expected_width + + def get_fake_oneformer_outputs(self): + return OneFormerForUniversalSegmentationOutput( + # +1 for null class + class_queries_logits=torch.randn((self.batch_size, self.num_queries, self.num_classes + 1)), + masks_queries_logits=torch.randn((self.batch_size, self.num_queries, self.height, self.width)), + ) + + +@require_torch +@require_vision +class OneFormerImageProcessingTest(FeatureExtractionSavingTestMixin, unittest.TestCase): + image_processing_class = OneFormerImageProcessor if (is_vision_available() and is_torch_available()) else None + # only for test_feat_extracttion_common.test_feat_extract_to_json_string + feature_extraction_class = image_processing_class + + def setUp(self): + self.image_processing_tester = OneFormerImageProcessorTester(self) + + @property + def feat_extract_dict(self): + return self.image_processing_tester.prepare_feat_extract_dict() + + def test_feat_extract_properties(self): + image_processor = self.image_processing_class(**self.feat_extract_dict) + self.assertTrue(hasattr(image_processor, "image_mean")) + self.assertTrue(hasattr(image_processor, "image_std")) + self.assertTrue(hasattr(image_processor, "do_normalize")) + self.assertTrue(hasattr(image_processor, "do_resize")) + self.assertTrue(hasattr(image_processor, "size")) + self.assertTrue(hasattr(image_processor, "ignore_index")) + self.assertTrue(hasattr(image_processor, "class_info_file")) + self.assertTrue(hasattr(image_processor, "num_text")) + self.assertTrue(hasattr(image_processor, "repo_path")) + self.assertTrue(hasattr(image_processor, "metadata")) + self.assertTrue(hasattr(image_processor, "reduce_labels")) + + def test_batch_feature(self): + pass + + def test_call_pil(self): + # Initialize image_processor + image_processor = self.image_processing_class(**self.feat_extract_dict) + # create random PIL images + image_inputs = prepare_image_inputs(self.image_processing_tester, equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = image_processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values + + expected_height, expected_width = self.image_processing_tester.get_expected_values(image_inputs) + + self.assertEqual( + encoded_images.shape, + (1, self.image_processing_tester.num_channels, expected_height, expected_width), + ) + + # Test batched + expected_height, expected_width = self.image_processing_tester.get_expected_values(image_inputs, batched=True) + + encoded_images = image_processor( + image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt" + ).pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.image_processing_tester.batch_size, + self.image_processing_tester.num_channels, + expected_height, + expected_width, + ), + ) + + def test_call_numpy(self): + # Initialize image_processor + image_processor = self.image_processing_class(**self.feat_extract_dict) + # create random numpy tensors + image_inputs = prepare_image_inputs(self.image_processing_tester, equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = image_processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values + + expected_height, expected_width = self.image_processing_tester.get_expected_values(image_inputs) + + self.assertEqual( + encoded_images.shape, + (1, self.image_processing_tester.num_channels, expected_height, expected_width), + ) + + # Test batched + expected_height, expected_width = self.image_processing_tester.get_expected_values(image_inputs, batched=True) + + encoded_images = image_processor( + image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt" + ).pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.image_processing_tester.batch_size, + self.image_processing_tester.num_channels, + expected_height, + expected_width, + ), + ) + + def test_call_pytorch(self): + # Initialize image_processor + image_processor = self.image_processing_class(**self.feat_extract_dict) + # create random PyTorch tensors + image_inputs = prepare_image_inputs(self.image_processing_tester, equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = image_processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values + + expected_height, expected_width = self.image_processing_tester.get_expected_values(image_inputs) + + self.assertEqual( + encoded_images.shape, + (1, self.image_processing_tester.num_channels, expected_height, expected_width), + ) + + # Test batched + expected_height, expected_width = self.image_processing_tester.get_expected_values(image_inputs, batched=True) + + encoded_images = image_processor( + image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt" + ).pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.image_processing_tester.batch_size, + self.image_processing_tester.num_channels, + expected_height, + expected_width, + ), + ) + + def test_equivalence_pad_and_create_pixel_mask(self): + # Initialize image_processors + image_processor_1 = self.image_processing_class(**self.feat_extract_dict) + image_processor_2 = self.image_processing_class( + do_resize=False, + do_normalize=False, + do_rescale=False, + num_labels=self.image_processing_tester.num_classes, + class_info_file="ade20k_panoptic.json", + num_text=self.image_processing_tester.num_text, + repo_path="shi-labs/oneformer_demo", + ) + # create random PyTorch tensors + image_inputs = prepare_image_inputs(self.image_processing_tester, equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test whether the method "pad_and_return_pixel_mask" and calling the image processor return the same tensors + encoded_images_with_method = image_processor_1.encode_inputs( + image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt" + ) + encoded_images = image_processor_2(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt") + + self.assertTrue( + torch.allclose(encoded_images_with_method["pixel_values"], encoded_images["pixel_values"], atol=1e-4) + ) + self.assertTrue( + torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4) + ) + + def comm_get_image_processor_inputs( + self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np" + ): + image_processor = self.image_processing_class(**self.feat_extract_dict) + # prepare image and target + num_labels = self.image_processing_tester.num_labels + annotations = None + instance_id_to_semantic_id = None + image_inputs = prepare_image_inputs(self.image_processing_tester, equal_resolution=False) + if with_segmentation_maps: + high = num_labels + if is_instance_map: + labels_expanded = list(range(num_labels)) * 2 + instance_id_to_semantic_id = { + instance_id: label_id for instance_id, label_id in enumerate(labels_expanded) + } + annotations = [ + np.random.randint(0, high * 2, (img.size[1], img.size[0])).astype(np.uint8) for img in image_inputs + ] + if segmentation_type == "pil": + annotations = [Image.fromarray(annotation) for annotation in annotations] + + inputs = image_processor( + image_inputs, + ["semantic"] * len(image_inputs), + annotations, + return_tensors="pt", + instance_id_to_semantic_id=instance_id_to_semantic_id, + pad_and_return_pixel_mask=True, + ) + + return inputs + + def test_init_without_params(self): + pass + + def test_call_with_segmentation_maps(self): + def common(is_instance_map=False, segmentation_type=None): + inputs = self.comm_get_image_processor_inputs( + with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type + ) + + mask_labels = inputs["mask_labels"] + class_labels = inputs["class_labels"] + pixel_values = inputs["pixel_values"] + text_inputs = inputs["text_inputs"] + + # check the batch_size + for mask_label, class_label, text_input in zip(mask_labels, class_labels, text_inputs): + self.assertEqual(mask_label.shape[0], class_label.shape[0]) + # this ensure padding has happened + self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:]) + self.assertEqual(len(text_input), self.image_processing_tester.num_text) + + common() + common(is_instance_map=True) + common(is_instance_map=False, segmentation_type="pil") + common(is_instance_map=True, segmentation_type="pil") + + def test_binary_mask_to_rle(self): + fake_binary_mask = np.zeros((20, 50)) + fake_binary_mask[0, 20:] = 1 + fake_binary_mask[1, :15] = 1 + fake_binary_mask[5, :10] = 1 + + rle = binary_mask_to_rle(fake_binary_mask) + self.assertEqual(len(rle), 4) + self.assertEqual(rle[0], 21) + self.assertEqual(rle[1], 45) + + def test_post_process_semantic_segmentation(self): + fature_extractor = self.image_processing_class( + num_labels=self.image_processing_tester.num_classes, + max_seq_length=77, + task_seq_length=77, + class_info_file="ade20k_panoptic.json", + num_text=self.image_processing_tester.num_text, + repo_path="shi-labs/oneformer_demo", + ) + outputs = self.image_processing_tester.get_fake_oneformer_outputs() + + segmentation = fature_extractor.post_process_semantic_segmentation(outputs) + + self.assertEqual(len(segmentation), self.image_processing_tester.batch_size) + self.assertEqual( + segmentation[0].shape, + ( + self.image_processing_tester.height, + self.image_processing_tester.width, + ), + ) + + target_sizes = [(1, 4) for i in range(self.image_processing_tester.batch_size)] + segmentation = fature_extractor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes) + + self.assertEqual(segmentation[0].shape, target_sizes[0]) + + def test_post_process_instance_segmentation(self): + image_processor = self.image_processing_class( + num_labels=self.image_processing_tester.num_classes, + max_seq_length=77, + task_seq_length=77, + class_info_file="ade20k_panoptic.json", + num_text=self.image_processing_tester.num_text, + repo_path="shi-labs/oneformer_demo", + ) + outputs = self.image_processing_tester.get_fake_oneformer_outputs() + segmentation = image_processor.post_process_instance_segmentation(outputs, threshold=0) + + self.assertTrue(len(segmentation) == self.image_processing_tester.batch_size) + for el in segmentation: + self.assertTrue("segmentation" in el) + self.assertTrue("segments_info" in el) + self.assertEqual(type(el["segments_info"]), list) + self.assertEqual( + el["segmentation"].shape, (self.image_processing_tester.height, self.image_processing_tester.width) + ) + + def test_post_process_panoptic_segmentation(self): + image_processor = self.image_processing_class( + num_labels=self.image_processing_tester.num_classes, + max_seq_length=77, + task_seq_length=77, + class_info_file="ade20k_panoptic.json", + num_text=self.image_processing_tester.num_text, + repo_path="shi-labs/oneformer_demo", + ) + outputs = self.image_processing_tester.get_fake_oneformer_outputs() + segmentation = image_processor.post_process_panoptic_segmentation(outputs, threshold=0) + + self.assertTrue(len(segmentation) == self.image_processing_tester.batch_size) + for el in segmentation: + self.assertTrue("segmentation" in el) + self.assertTrue("segments_info" in el) + self.assertEqual(type(el["segments_info"]), list) + self.assertEqual( + el["segmentation"].shape, (self.image_processing_tester.height, self.image_processing_tester.width) + ) diff --git a/tests/models/oneformer/test_modeling_oneformer.py b/tests/models/oneformer/test_modeling_oneformer.py new file mode 100644 index 000000000000..7fce79bbf605 --- /dev/null +++ b/tests/models/oneformer/test_modeling_oneformer.py @@ -0,0 +1,536 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch OneFormer model. """ + +import copy +import inspect +import unittest + +import numpy as np + +from tests.test_modeling_common import floats_tensor +from transformers import OneFormerConfig, is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device +from transformers.utils import cached_property + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin + + +if is_torch_available(): + import torch + + from transformers import OneFormerForUniversalSegmentation, OneFormerModel + + if is_vision_available(): + from transformers import OneFormerProcessor + +if is_vision_available(): + from PIL import Image + + +def _config_zero_init(config): + configs_no_init = copy.deepcopy(config) + for key in configs_no_init.__dict__.keys(): + if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key: + setattr(configs_no_init, key, 1e-10) + return configs_no_init + + +class OneFormerModelTester: + def __init__( + self, + parent, + batch_size=2, + is_training=True, + use_auxiliary_loss=False, + num_queries=10, + num_channels=3, + min_size=32 * 8, + max_size=32 * 8, + num_labels=4, + hidden_dim=64, + sequence_length=77, + n_ctx=4, + ): + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + self.use_auxiliary_loss = use_auxiliary_loss + self.num_queries = num_queries + self.num_channels = num_channels + self.min_size = min_size + self.max_size = max_size + self.num_labels = num_labels + self.hidden_dim = hidden_dim + self.sequence_length = sequence_length + self.n_ctx = n_ctx + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size]).to( + torch_device + ) + + task_inputs = torch.randint(high=49408, size=(self.batch_size, self.sequence_length)).to(torch_device).long() + + pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device) + + text_inputs = ( + torch.randint(high=49408, size=(self.batch_size, self.num_queries - self.n_ctx, self.sequence_length)) + .to(torch_device) + .long() + ) + + mask_labels = ( + torch.rand([self.batch_size, self.num_labels, self.min_size, self.max_size], device=torch_device) > 0.5 + ).float() + class_labels = (torch.rand((self.batch_size, self.num_labels), device=torch_device) > 0.5).long() + + config = self.get_config() + return config, pixel_values, task_inputs, text_inputs, pixel_mask, mask_labels, class_labels + + def get_config(self): + config = OneFormerConfig( + hidden_size=self.hidden_dim, + ) + + config.num_queries = self.num_queries + config.num_labels = self.num_labels + + config.backbone_config.depths = [1, 1, 1, 1] + config.backbone_config.num_channels = self.num_channels + + config.encoder_feedforward_dim = 64 + config.dim_feedforward = 128 + config.hidden_dim = self.hidden_dim + config.mask_dim = self.hidden_dim + config.conv_dim = self.hidden_dim + + config.text_encoder_width = self.hidden_dim + config.task_seq_len = self.sequence_length + config.max_seq_len = self.sequence_length + config.text_encoder_context_length = self.sequence_length + config.text_encoder_n_ctx = self.n_ctx + + return config + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, task_inputs, pixel_mask, _, _, _ = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask, "task_inputs": task_inputs} + return config, inputs_dict + + def check_output_hidden_state(self, output, config): + encoder_hidden_states = output.encoder_hidden_states + pixel_decoder_hidden_states = output.pixel_decoder_hidden_states + transformer_decoder_hidden_states = output.transformer_decoder_hidden_states + + self.parent.assertTrue(len(encoder_hidden_states), len(config.backbone_config.depths)) + self.parent.assertTrue(len(pixel_decoder_hidden_states), config.encoder_layers) + self.parent.assertTrue(len(transformer_decoder_hidden_states), config.decoder_layers - 1) + + def create_and_check_oneformer_model( + self, config, pixel_values, task_inputs, pixel_mask, output_hidden_states=False + ): + with torch.no_grad(): + model = OneFormerModel(config=config) + model.to(torch_device) + model.eval() + + output = model(pixel_values=pixel_values, task_inputs=task_inputs, pixel_mask=pixel_mask) + output = model(pixel_values, task_inputs=task_inputs, output_hidden_states=True) + # the correct shape of output.transformer_decoder_hidden_states ensure the correcteness of the + # encoder and pixel decoder + self.parent.assertEqual( + output.transformer_decoder_object_queries.shape, + (self.batch_size, self.num_queries, self.hidden_dim), + ) + # let's ensure the other two hidden state exists + self.parent.assertTrue(output.pixel_decoder_hidden_states is not None) + self.parent.assertTrue(output.encoder_hidden_states is not None) + + if output_hidden_states: + self.check_output_hidden_state(output, config) + + def create_and_check_oneformer_universal_segmentation_head_model( + self, config, pixel_values, task_inputs, text_inputs, pixel_mask, mask_labels, class_labels + ): + model = OneFormerForUniversalSegmentation(config=config) + model.to(torch_device) + model.eval() + + def comm_check_on_output(result): + # let's still check that all the required stuff is there + self.parent.assertTrue(result.transformer_decoder_hidden_states is not None) + self.parent.assertTrue(result.pixel_decoder_hidden_states is not None) + self.parent.assertTrue(result.encoder_hidden_states is not None) + # okay, now we need to check the logits shape + # due to the encoder compression, masks have a //4 spatial size + self.parent.assertEqual( + result.masks_queries_logits.shape, + (self.batch_size, self.num_queries, self.min_size // 4, self.max_size // 4), + ) + # + 1 for null class + self.parent.assertEqual( + result.class_queries_logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1) + ) + + with torch.no_grad(): + result = model(pixel_values=pixel_values, task_inputs=task_inputs, pixel_mask=pixel_mask) + result = model(pixel_values, task_inputs) + + comm_check_on_output(result) + + config.is_training = True + model = OneFormerForUniversalSegmentation(config=config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + result = model( + pixel_values=pixel_values, + task_inputs=task_inputs, + pixel_mask=pixel_mask, + mask_labels=mask_labels, + class_labels=class_labels, + text_inputs=text_inputs, + ) + + comm_check_on_output(result) + + self.parent.assertTrue(result.loss is not None) + self.parent.assertEqual(result.loss.shape, torch.Size([1])) + + +@require_torch +class OneFormerModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (OneFormerModel, OneFormerForUniversalSegmentation) if is_torch_available() else () + + is_encoder_decoder = False + test_pruning = False + test_head_masking = False + test_missing_keys = False + + def setUp(self): + self.model_tester = OneFormerModelTester(self) + self.config_tester = ConfigTester(self, config_class=OneFormerConfig, has_text_modality=False) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_oneformer_model(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + self.model_tester.create_and_check_oneformer_model(config, **inputs, output_hidden_states=False) + + def test_oneformer_universal_segmentation_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_oneformer_universal_segmentation_head_model(*config_and_inputs) + + def test_model_main_input_name(self): + for model_class in self.all_model_classes: + model_signature = inspect.signature(getattr(model_class, "forward")) + # The main input is the name of the argument after `self` + observed_main_input_name = list(model_signature.parameters.keys())[1:3] + self.assertEqual(model_class.main_input_name, observed_main_input_name) + + @unittest.skip(reason="OneFormer uses two main inputs") + def test_torchscript_simple(self): + pass + + @unittest.skip(reason="OneFormer uses two main inputs") + def test_torchscript_output_attentions(self): + pass + + @unittest.skip(reason="OneFormer uses two main inputs") + def test_torchscript_output_hidden_state(self): + pass + + @unittest.skip(reason="OneFormer does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="OneFormer does not have a get_input_embeddings method") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="OneFormer is not a generative model") + def test_generate_without_input_ids(self): + pass + + @unittest.skip(reason="OneFormer does not use token embeddings") + def test_resize_tokens_embeddings(self): + pass + + @require_torch_multi_gpu + @unittest.skip( + reason="OneFormer has some layers using `add_module` which doesn't work well with `nn.DataParallel`" + ) + def test_multi_gpu_data_parallel_forward(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values", "task_inputs"] + self.assertListEqual(arg_names[:2], expected_arg_names) + + @slow + def test_model_from_pretrained(self): + for model_name in ["shi-labs/oneformer_ade20k_swin_tiny"]: + model = OneFormerModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + def test_model_with_labels(self): + size = (self.model_tester.min_size,) * 2 + inputs = { + "pixel_values": torch.randn((2, 3, *size), device=torch_device), + "task_inputs": torch.randint(high=49408, size=(2, 77), device=torch_device).long(), + "text_inputs": torch.randint(high=49408, size=(2, 134, 77), device=torch_device).long(), + "mask_labels": torch.randn((2, 150, *size), device=torch_device), + "class_labels": torch.zeros(2, 150, device=torch_device).long(), + } + + config = OneFormerConfig() + config.is_training = True + + model = OneFormerForUniversalSegmentation(config).to(torch_device) + outputs = model(**inputs) + self.assertTrue(outputs.loss is not None) + + def test_hidden_states_output(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + self.model_tester.create_and_check_oneformer_model(config, **inputs, output_hidden_states=True) + + def test_attention_outputs(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device) + outputs = model(**inputs, output_attentions=True) + self.assertTrue(outputs.attentions is not None) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.contrastive_temperature = 1 + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def test_training(self): + if not self.model_tester.is_training: + return + # only OneFormerForUniversalSegmentation has the loss + model_class = self.all_model_classes[1] + ( + config, + pixel_values, + task_inputs, + text_inputs, + pixel_mask, + mask_labels, + class_labels, + ) = self.model_tester.prepare_config_and_inputs() + config.is_training = True + + model = model_class(config) + model.to(torch_device) + model.train() + + loss = model( + pixel_values, task_inputs, text_inputs=text_inputs, mask_labels=mask_labels, class_labels=class_labels + ).loss + loss.backward() + + def test_retain_grad_hidden_states_attentions(self): + # only OneFormerForUniversalSegmentation has the loss + model_class = self.all_model_classes[1] + ( + config, + pixel_values, + task_inputs, + text_inputs, + pixel_mask, + mask_labels, + class_labels, + ) = self.model_tester.prepare_config_and_inputs() + config.output_hidden_states = True + config.output_attentions = True + config.is_training = True + + model = model_class(config) + model.to(torch_device) + model.train() + + outputs = model( + pixel_values, task_inputs, text_inputs=text_inputs, mask_labels=mask_labels, class_labels=class_labels + ) + + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_hidden_states.retain_grad() + + pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states[0] + pixel_decoder_hidden_states.retain_grad() + + transformer_decoder_class_predictions = outputs.transformer_decoder_class_predictions + transformer_decoder_class_predictions.retain_grad() + + transformer_decoder_mask_predictions = outputs.transformer_decoder_mask_predictions + transformer_decoder_mask_predictions.retain_grad() + + attentions = outputs.attentions[0][0] + attentions.retain_grad() + + outputs.loss.backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(pixel_decoder_hidden_states.grad) + self.assertIsNotNone(transformer_decoder_class_predictions.grad) + self.assertIsNotNone(transformer_decoder_mask_predictions.grad) + self.assertIsNotNone(attentions.grad) + + +TOLERANCE = 1e-4 + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_vision +@slow +class OneFormerModelIntegrationTest(unittest.TestCase): + @cached_property + def model_checkpoints(self): + return "shi-labs/oneformer_ade20k_swin_tiny" + + @cached_property + def default_processor(self): + return OneFormerProcessor.from_pretrained(self.model_checkpoints) if is_vision_available() else None + + def test_inference_no_head(self): + model = OneFormerModel.from_pretrained(self.model_checkpoints).to(torch_device) + processor = self.default_processor + image = prepare_img() + inputs = processor(image, ["semantic"], return_tensors="pt").to(torch_device) + inputs_shape = inputs["pixel_values"].shape + # check size + self.assertEqual(inputs_shape, (1, 3, 512, 682)) + + task_inputs_shape = inputs["task_inputs"].shape + # check size + self.assertEqual(task_inputs_shape, (1, 77)) + + with torch.no_grad(): + outputs = model(**inputs) + + expected_slice_hidden_state = torch.tensor( + [[0.2724, 0.8287, 0.6025], [1.2706, 1.1252, 1.1445], [1.1357, 0.6150, 0.4185]] + ).to(torch_device) + self.assertTrue( + torch.allclose( + outputs.encoder_hidden_states[-1][0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE + ) + ) + + expected_slice_hidden_state = torch.tensor( + [[1.0581, 1.2275, 1.2000], [1.1901, 1.2925, 1.2861], [1.1578, 1.2558, 1.3212]] + ).to(torch_device) + self.assertTrue( + torch.allclose( + outputs.pixel_decoder_hidden_states[0][0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE + ) + ) + + expected_slice_hidden_state = torch.tensor( + [[3.0711, -1.1855, -5.1095], [3.5536, -3.2710, -5.2052], [2.6020, -4.3605, -4.1422]] + ).to(torch_device) + self.assertTrue( + torch.allclose( + outputs.transformer_decoder_class_predictions[0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE + ) + ) + + def test_inference_universal_segmentation_head(self): + model = OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval() + processor = self.default_processor + image = prepare_img() + inputs = processor(image, ["semantic"], return_tensors="pt").to(torch_device) + inputs_shape = inputs["pixel_values"].shape + # check size + self.assertEqual(inputs_shape, (1, 3, 512, 682)) + + with torch.no_grad(): + outputs = model(**inputs) + + # masks_queries_logits + masks_queries_logits = outputs.masks_queries_logits + self.assertEqual( + masks_queries_logits.shape, + (1, model.config.num_queries, inputs_shape[-2] // 4, (inputs_shape[-1] + 2) // 4), + ) + expected_slice = [[[3.1215, 4.1250, 4.1106], [2.8183, 3.4623, 3.5512], [2.4550, 2.9841, 3.5081]]] + expected_slice = torch.tensor(expected_slice).to(torch_device) + self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE)) + # class_queries_logits + class_queries_logits = outputs.class_queries_logits + self.assertEqual( + class_queries_logits.shape, + (1, model.config.num_queries, model.config.num_labels + 1), + ) + expected_slice = torch.tensor( + [[3.0711, -1.1855, -5.1095], [3.5536, -3.2710, -5.2052], [2.6020, -4.3605, -4.1422]] + ).to(torch_device) + self.assertTrue(torch.allclose(class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE)) + + def test_with_segmentation_maps_and_loss(self): + dummy_model = OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints) + processor = self.default_processor + processor.image_processor.num_text = dummy_model.config.num_queries - dummy_model.config.text_encoder_n_ctx + dummy_model.config.is_training = True + model = OneFormerForUniversalSegmentation(dummy_model.config).to(torch_device).eval() + del dummy_model + + inputs = processor( + [np.zeros((3, 512, 640)), np.zeros((3, 512, 640))], + ["semantic", "semantic"], + segmentation_maps=[np.zeros((384, 384)).astype(np.float32), np.zeros((384, 384)).astype(np.float32)], + return_tensors="pt", + ) + + inputs["pixel_values"] = inputs["pixel_values"].to(torch_device) + inputs["task_inputs"] = inputs["task_inputs"].to(torch_device) + inputs["text_inputs"] = inputs["text_inputs"].to(torch_device) + inputs["mask_labels"] = [el.to(torch_device) for el in inputs["mask_labels"]] + inputs["class_labels"] = [el.to(torch_device) for el in inputs["class_labels"]] + + with torch.no_grad(): + outputs = model(**inputs) + + self.assertTrue(outputs.loss is not None) diff --git a/tests/models/oneformer/test_processor_oneformer.py b/tests/models/oneformer/test_processor_oneformer.py new file mode 100644 index 000000000000..72d940df8b18 --- /dev/null +++ b/tests/models/oneformer/test_processor_oneformer.py @@ -0,0 +1,833 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import os +import tempfile +import unittest + +import numpy as np +from datasets import load_dataset + +from huggingface_hub import hf_hub_download +from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_feature_extraction_common import prepare_image_inputs + + +if is_torch_available(): + import torch + + if is_vision_available(): + from transformers import CLIPTokenizer, OneFormerImageProcessor, OneFormerProcessor + from transformers.models.oneformer.image_processing_oneformer import binary_mask_to_rle + from transformers.models.oneformer.modeling_oneformer import OneFormerForUniversalSegmentationOutput + +if is_vision_available(): + from PIL import Image + + +def prepare_metadata(class_info_file, repo_path="shi-labs/oneformer_demo"): + with open(hf_hub_download(repo_path, class_info_file, repo_type="dataset"), "r") as f: + class_info = json.load(f) + metadata = {} + class_names = [] + thing_ids = [] + + for key, info in class_info.items(): + metadata[key] = info["name"] + class_names.append(info["name"]) + if info["isthing"]: + thing_ids.append(int(key)) + + metadata["thing_ids"] = thing_ids + metadata["class_names"] = class_names + return metadata + + +class OneFormerProcessorTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + min_resolution=30, + max_resolution=400, + size=None, + do_resize=True, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + num_labels=10, + reduce_labels=False, + ignore_index=255, + max_seq_length=77, + task_seq_length=77, + model_repo="shi-labs/oneformer_ade20k_swin_tiny", + class_info_file="ade20k_panoptic.json", + num_text=10, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = {"shortest_edge": 32, "longest_edge": 1333} if size is None else size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.max_seq_length = max_seq_length + self.task_seq_length = task_seq_length + self.class_info_file = class_info_file + self.metadata = prepare_metadata(class_info_file) + self.num_text = num_text + self.model_repo = model_repo + + # for the post_process_functions + self.batch_size = 2 + self.num_queries = 10 + self.num_classes = 10 + self.height = 3 + self.width = 4 + self.num_labels = num_labels + self.reduce_labels = reduce_labels + self.ignore_index = ignore_index + + def prepare_processor_dict(self): + image_processor_dict = { + "do_resize": self.do_resize, + "size": self.size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "num_labels": self.num_labels, + "reduce_labels": self.reduce_labels, + "ignore_index": self.ignore_index, + "class_info_file": self.class_info_file, + "metadata": self.metadata, + "num_text": self.num_text, + } + + image_processor = OneFormerImageProcessor(**image_processor_dict) + tokenizer = CLIPTokenizer.from_pretrained(self.model_repo) + + return { + "image_processor": image_processor, + "tokenizer": tokenizer, + "max_seq_length": self.max_seq_length, + "task_seq_length": self.task_seq_length, + } + + def get_expected_values(self, image_inputs, batched=False): + """ + This function computes the expected height and width when providing images to OneFormerProcessor, + assuming do_resize is set to True with a scalar size. It also provides the expected sequence length + for the task_inputs and text_list_input. + """ + if not batched: + image = image_inputs[0] + if isinstance(image, Image.Image): + w, h = image.size + else: + h, w = image.shape[1], image.shape[2] + if w < h: + expected_height = int(self.size["shortest_edge"] * h / w) + expected_width = self.size["shortest_edge"] + elif w > h: + expected_height = self.size["shortest_edge"] + expected_width = int(self.size["shortest_edge"] * w / h) + else: + expected_height = self.size["shortest_edge"] + expected_width = self.size["shortest_edge"] + + else: + expected_values = [] + for image in image_inputs: + expected_height, expected_width, expected_sequence_length = self.get_expected_values([image]) + expected_values.append((expected_height, expected_width, expected_sequence_length)) + expected_height = max(expected_values, key=lambda item: item[0])[0] + expected_width = max(expected_values, key=lambda item: item[1])[1] + + expected_sequence_length = self.max_seq_length + + return expected_height, expected_width, expected_sequence_length + + def get_fake_oneformer_outputs(self): + return OneFormerForUniversalSegmentationOutput( + # +1 for null class + class_queries_logits=torch.randn((self.batch_size, self.num_queries, self.num_classes + 1)), + masks_queries_logits=torch.randn((self.batch_size, self.num_queries, self.height, self.width)), + ) + + +@require_torch +@require_vision +class OneFormerProcessingTest(unittest.TestCase): + processing_class = OneFormerProcessor if (is_vision_available() and is_torch_available()) else None + # only for test_feat_extracttion_common.test_feat_extract_to_json_string + feature_extraction_class = processing_class + + def setUp(self): + self.processing_tester = OneFormerProcessorTester(self) + + @property + def processor_dict(self): + return self.processing_tester.prepare_processor_dict() + + def test_feat_extract_properties(self): + processor = self.processing_class(**self.processor_dict) + self.assertTrue(hasattr(processor, "image_processor")) + self.assertTrue(hasattr(processor, "tokenizer")) + self.assertTrue(hasattr(processor, "max_seq_length")) + self.assertTrue(hasattr(processor, "task_seq_length")) + + def test_batch_feature(self): + pass + + def test_call_pil(self): + # Initialize processor + processor = self.processing_class(**self.processor_dict) + # create random PIL images + image_inputs = prepare_image_inputs(self.processing_tester, equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values + + expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values( + image_inputs + ) + + self.assertEqual( + encoded_images.shape, + (1, self.processing_tester.num_channels, expected_height, expected_width), + ) + + tokenized_task_inputs = processor(image_inputs[0], ["semantic"], return_tensors="pt").task_inputs + + self.assertEqual( + tokenized_task_inputs.shape, + (1, expected_sequence_length), + ) + + # Test batched + expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values( + image_inputs, batched=True + ) + + encoded_images = processor(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.processing_tester.batch_size, + self.processing_tester.num_channels, + expected_height, + expected_width, + ), + ) + + tokenized_task_inputs = processor( + image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt" + ).task_inputs + + self.assertEqual( + tokenized_task_inputs.shape, + (self.processing_tester.batch_size, expected_sequence_length), + ) + + def test_call_numpy(self): + # Initialize processor + processor = self.processing_class(**self.processor_dict) + # create random numpy tensors + image_inputs = prepare_image_inputs(self.processing_tester, equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values + + expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values( + image_inputs + ) + + self.assertEqual( + encoded_images.shape, + (1, self.processing_tester.num_channels, expected_height, expected_width), + ) + + tokenized_task_inputs = processor(image_inputs[0], ["semantic"], return_tensors="pt").task_inputs + + self.assertEqual( + tokenized_task_inputs.shape, + (1, expected_sequence_length), + ) + + # Test batched + expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values( + image_inputs, batched=True + ) + + encoded_images = processor(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.processing_tester.batch_size, + self.processing_tester.num_channels, + expected_height, + expected_width, + ), + ) + + tokenized_task_inputs = processor( + image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt" + ).task_inputs + + self.assertEqual( + tokenized_task_inputs.shape, + (self.processing_tester.batch_size, expected_sequence_length), + ) + + def test_call_pytorch(self): + # Initialize processor + processor = self.processing_class(**self.processor_dict) + # create random PyTorch tensors + image_inputs = prepare_image_inputs(self.processing_tester, equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values + + expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values( + image_inputs + ) + + self.assertEqual( + encoded_images.shape, + (1, self.processing_tester.num_channels, expected_height, expected_width), + ) + + tokenized_task_inputs = processor(image_inputs[0], ["semantic"], return_tensors="pt").task_inputs + + self.assertEqual( + tokenized_task_inputs.shape, + (1, expected_sequence_length), + ) + + # Test batched + expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values( + image_inputs, batched=True + ) + + encoded_images = processor(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.processing_tester.batch_size, + self.processing_tester.num_channels, + expected_height, + expected_width, + ), + ) + + tokenized_task_inputs = processor( + image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt" + ).task_inputs + + self.assertEqual( + tokenized_task_inputs.shape, + (self.processing_tester.batch_size, expected_sequence_length), + ) + + def test_equivalence_pad_and_create_pixel_mask(self): + # Initialize processors + processor_1 = self.processing_class(**self.processor_dict) + + image_processor = OneFormerImageProcessor( + do_resize=False, + do_normalize=False, + do_rescale=False, + num_labels=self.processing_tester.num_classes, + class_info_file="ade20k_panoptic.json", + num_text=self.processing_tester.num_text, + ) + tokenizer = CLIPTokenizer.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny") + processor_2 = self.processing_class( + image_processor=image_processor, tokenizer=tokenizer, max_seq_length=77, task_seq_length=77 + ) + + # create random PyTorch tensors + image_inputs = prepare_image_inputs(self.processing_tester, equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test whether the method "pad_and_return_pixel_mask" and calling the image processor return the same tensors + encoded_images_with_method = processor_1.encode_inputs( + image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt" + ) + encoded_images = processor_2(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt") + + self.assertTrue( + torch.allclose(encoded_images_with_method["pixel_values"], encoded_images["pixel_values"], atol=1e-4) + ) + self.assertTrue( + torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4) + ) + + def comm_get_processor_inputs(self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"): + processor = self.processing_class(**self.processor_dict) + # prepare image and target + num_labels = self.processing_tester.num_labels + annotations = None + instance_id_to_semantic_id = None + image_inputs = prepare_image_inputs(self.processing_tester, equal_resolution=False) + if with_segmentation_maps: + high = num_labels + if is_instance_map: + labels_expanded = list(range(num_labels)) * 2 + instance_id_to_semantic_id = { + instance_id: label_id for instance_id, label_id in enumerate(labels_expanded) + } + annotations = [ + np.random.randint(0, high * 2, (img.size[1], img.size[0])).astype(np.uint8) for img in image_inputs + ] + if segmentation_type == "pil": + annotations = [Image.fromarray(annotation) for annotation in annotations] + + inputs = processor( + image_inputs, + ["semantic"] * len(image_inputs), + annotations, + return_tensors="pt", + instance_id_to_semantic_id=instance_id_to_semantic_id, + pad_and_return_pixel_mask=True, + ) + + return inputs + + def test_init_without_params(self): + pass + + def test_feat_extract_from_and_save_pretrained(self): + feat_extract_first = self.feature_extraction_class(**self.processor_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + feat_extract_first.save_pretrained(tmpdirname) + check_json_file_has_correct_format(os.path.join(tmpdirname, "preprocessor_config.json")) + feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname) + + self.assertEqual(feat_extract_second.image_processor.to_dict(), feat_extract_first.image_processor.to_dict()) + self.assertIsInstance(feat_extract_first.image_processor, OneFormerImageProcessor) + self.assertIsInstance(feat_extract_first.tokenizer, CLIPTokenizer) + + def test_call_with_segmentation_maps(self): + def common(is_instance_map=False, segmentation_type=None): + inputs = self.comm_get_processor_inputs( + with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type + ) + + mask_labels = inputs["mask_labels"] + class_labels = inputs["class_labels"] + pixel_values = inputs["pixel_values"] + text_inputs = inputs["text_inputs"] + + # check the batch_size + for mask_label, class_label, text_input in zip(mask_labels, class_labels, text_inputs): + self.assertEqual(mask_label.shape[0], class_label.shape[0]) + # this ensure padding has happened + self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:]) + self.assertEqual(text_input.shape[0], self.processing_tester.num_text) + + common() + common(is_instance_map=True) + common(is_instance_map=False, segmentation_type="pil") + common(is_instance_map=True, segmentation_type="pil") + + def test_integration_semantic_segmentation(self): + # load 2 images and corresponding panoptic annotations from the hub + dataset = load_dataset("nielsr/ade20k-panoptic-demo") + image1 = dataset["train"][0]["image"] + image2 = dataset["train"][1]["image"] + segments_info1 = dataset["train"][0]["segments_info"] + segments_info2 = dataset["train"][1]["segments_info"] + annotation1 = dataset["train"][0]["label"] + annotation2 = dataset["train"][1]["label"] + + def rgb_to_id(color): + if isinstance(color, np.ndarray) and len(color.shape) == 3: + if color.dtype == np.uint8: + color = color.astype(np.int32) + return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2] + return int(color[0] + 256 * color[1] + 256 * 256 * color[2]) + + def create_panoptic_map(annotation, segments_info): + annotation = np.array(annotation) + # convert RGB to segment IDs per pixel + # 0 is the "ignore" label, for which we don't need to make binary masks + panoptic_map = rgb_to_id(annotation) + + # create mapping between segment IDs and semantic classes + inst2class = {segment["id"]: segment["category_id"] for segment in segments_info} + + return panoptic_map, inst2class + + panoptic_map1, inst2class1 = create_panoptic_map(annotation1, segments_info1) + panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2) + + image_processor = OneFormerImageProcessor( + reduce_labels=True, + ignore_index=0, + size=(512, 512), + class_info_file="ade20k_panoptic.json", + num_text=self.processing_tester.num_text, + ) + + tokenizer = CLIPTokenizer.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny") + + processor = OneFormerProcessor( + image_processor=image_processor, + tokenizer=tokenizer, + max_seq_length=77, + task_seq_length=77, + ) + + # prepare the images and annotations + pixel_values_list = [np.moveaxis(np.array(image1), -1, 0), np.moveaxis(np.array(image2), -1, 0)] + inputs = processor.encode_inputs( + pixel_values_list, + ["semantic", "semantic"], + [panoptic_map1, panoptic_map2], + instance_id_to_semantic_id=[inst2class1, inst2class2], + return_tensors="pt", + ) + + # verify the pixel values, task inputs, text inputs and pixel mask + self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 711)) + self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 711)) + self.assertEqual(inputs["task_inputs"].shape, (2, 77)) + self.assertEqual(inputs["text_inputs"].shape, (2, self.processing_tester.num_text, 77)) + + # verify the class labels + self.assertEqual(len(inputs["class_labels"]), 2) + # fmt: off + expected_class_labels = torch.tensor([4, 17, 32, 42, 12, 3, 5, 0, 43, 96, 104, 31, 125, 138, 87, 149]) # noqa: E231 + # fmt: on + self.assertTrue(torch.allclose(inputs["class_labels"][0], expected_class_labels)) + # fmt: off + expected_class_labels = torch.tensor([19, 67, 82, 17, 12, 42, 3, 14, 5, 0, 115, 43, 8, 138, 125, 143]) # noqa: E231 + # fmt: on + self.assertTrue(torch.allclose(inputs["class_labels"][1], expected_class_labels)) + + # verify the task inputs + self.assertEqual(len(inputs["task_inputs"]), 2) + self.assertEqual(inputs["task_inputs"][0].sum().item(), 141082) + self.assertEqual(inputs["task_inputs"][0].sum().item(), inputs["task_inputs"][1].sum().item()) + + # verify the text inputs + self.assertEqual(len(inputs["text_inputs"]), 2) + self.assertEqual(inputs["text_inputs"][0].sum().item(), 1095752) + self.assertEqual(inputs["text_inputs"][1].sum().item(), 1062468) + + # verify the mask labels + self.assertEqual(len(inputs["mask_labels"]), 2) + self.assertEqual(inputs["mask_labels"][0].shape, (16, 512, 711)) + self.assertEqual(inputs["mask_labels"][1].shape, (16, 512, 711)) + self.assertEqual(inputs["mask_labels"][0].sum().item(), 315193.0) + self.assertEqual(inputs["mask_labels"][1].sum().item(), 350747.0) + + def test_integration_instance_segmentation(self): + # load 2 images and corresponding panoptic annotations from the hub + dataset = load_dataset("nielsr/ade20k-panoptic-demo") + image1 = dataset["train"][0]["image"] + image2 = dataset["train"][1]["image"] + segments_info1 = dataset["train"][0]["segments_info"] + segments_info2 = dataset["train"][1]["segments_info"] + annotation1 = dataset["train"][0]["label"] + annotation2 = dataset["train"][1]["label"] + + def rgb_to_id(color): + if isinstance(color, np.ndarray) and len(color.shape) == 3: + if color.dtype == np.uint8: + color = color.astype(np.int32) + return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2] + return int(color[0] + 256 * color[1] + 256 * 256 * color[2]) + + def create_panoptic_map(annotation, segments_info): + annotation = np.array(annotation) + # convert RGB to segment IDs per pixel + # 0 is the "ignore" label, for which we don't need to make binary masks + panoptic_map = rgb_to_id(annotation) + + # create mapping between segment IDs and semantic classes + inst2class = {segment["id"]: segment["category_id"] for segment in segments_info} + + return panoptic_map, inst2class + + panoptic_map1, inst2class1 = create_panoptic_map(annotation1, segments_info1) + panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2) + + image_processor = OneFormerImageProcessor( + reduce_labels=True, + ignore_index=0, + size=(512, 512), + class_info_file="ade20k_panoptic.json", + num_text=self.processing_tester.num_text, + ) + + tokenizer = CLIPTokenizer.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny") + + processor = OneFormerProcessor( + image_processor=image_processor, + tokenizer=tokenizer, + max_seq_length=77, + task_seq_length=77, + ) + + # prepare the images and annotations + pixel_values_list = [np.moveaxis(np.array(image1), -1, 0), np.moveaxis(np.array(image2), -1, 0)] + inputs = processor.encode_inputs( + pixel_values_list, + ["instance", "instance"], + [panoptic_map1, panoptic_map2], + instance_id_to_semantic_id=[inst2class1, inst2class2], + return_tensors="pt", + ) + + # verify the pixel values, task inputs, text inputs and pixel mask + self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 711)) + self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 711)) + self.assertEqual(inputs["task_inputs"].shape, (2, 77)) + self.assertEqual(inputs["text_inputs"].shape, (2, self.processing_tester.num_text, 77)) + + # verify the class labels + self.assertEqual(len(inputs["class_labels"]), 2) + # fmt: off + expected_class_labels = torch.tensor([32, 42, 42, 42, 42, 42, 42, 42, 32, 12, 12, 12, 12, 12, 42, 42, 12, 12, 12, 42, 12, 12, 12, 12, 12, 12, 12, 12, 12, 42, 42, 42, 12, 42, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 43, 43, 43, 43, 104, 43, 31, 125, 31, 125, 138, 87, 125, 149, 138, 125, 87, 87]) # noqa: E231 + # fmt: on + self.assertTrue(torch.allclose(inputs["class_labels"][0], expected_class_labels)) + # fmt: off + expected_class_labels = torch.tensor([19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 67, 82, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 12, 12, 42, 12, 12, 12, 12, 14, 12, 12, 12, 12, 12, 12, 12, 12, 14, 12, 12, 115, 43, 43, 115, 43, 43, 43, 8, 8, 8, 138, 138, 125, 143]) # noqa: E231 + # fmt: on + self.assertTrue(torch.allclose(inputs["class_labels"][1], expected_class_labels)) + + # verify the task inputs + self.assertEqual(len(inputs["task_inputs"]), 2) + self.assertEqual(inputs["task_inputs"][0].sum().item(), 144985) + self.assertEqual(inputs["task_inputs"][0].sum().item(), inputs["task_inputs"][1].sum().item()) + + # verify the text inputs + self.assertEqual(len(inputs["text_inputs"]), 2) + self.assertEqual(inputs["text_inputs"][0].sum().item(), 1037040) + self.assertEqual(inputs["text_inputs"][1].sum().item(), 1044078) + + # verify the mask labels + self.assertEqual(len(inputs["mask_labels"]), 2) + self.assertEqual(inputs["mask_labels"][0].shape, (73, 512, 711)) + self.assertEqual(inputs["mask_labels"][1].shape, (57, 512, 711)) + self.assertEqual(inputs["mask_labels"][0].sum().item(), 35040.0) + self.assertEqual(inputs["mask_labels"][1].sum().item(), 98228.0) + + def test_integration_panoptic_segmentation(self): + # load 2 images and corresponding panoptic annotations from the hub + dataset = load_dataset("nielsr/ade20k-panoptic-demo") + image1 = dataset["train"][0]["image"] + image2 = dataset["train"][1]["image"] + segments_info1 = dataset["train"][0]["segments_info"] + segments_info2 = dataset["train"][1]["segments_info"] + annotation1 = dataset["train"][0]["label"] + annotation2 = dataset["train"][1]["label"] + + def rgb_to_id(color): + if isinstance(color, np.ndarray) and len(color.shape) == 3: + if color.dtype == np.uint8: + color = color.astype(np.int32) + return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2] + return int(color[0] + 256 * color[1] + 256 * 256 * color[2]) + + def create_panoptic_map(annotation, segments_info): + annotation = np.array(annotation) + # convert RGB to segment IDs per pixel + # 0 is the "ignore" label, for which we don't need to make binary masks + panoptic_map = rgb_to_id(annotation) + + # create mapping between segment IDs and semantic classes + inst2class = {segment["id"]: segment["category_id"] for segment in segments_info} + + return panoptic_map, inst2class + + panoptic_map1, inst2class1 = create_panoptic_map(annotation1, segments_info1) + panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2) + + image_processor = OneFormerImageProcessor( + reduce_labels=True, + ignore_index=0, + size=(512, 512), + class_info_file="ade20k_panoptic.json", + num_text=self.processing_tester.num_text, + ) + + tokenizer = CLIPTokenizer.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny") + + processor = OneFormerProcessor( + image_processor=image_processor, + tokenizer=tokenizer, + max_seq_length=77, + task_seq_length=77, + ) + + # prepare the images and annotations + pixel_values_list = [np.moveaxis(np.array(image1), -1, 0), np.moveaxis(np.array(image2), -1, 0)] + inputs = processor.encode_inputs( + pixel_values_list, + ["panoptic", "panoptic"], + [panoptic_map1, panoptic_map2], + instance_id_to_semantic_id=[inst2class1, inst2class2], + return_tensors="pt", + ) + + # verify the pixel values, task inputs, text inputs and pixel mask + self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 711)) + self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 711)) + self.assertEqual(inputs["task_inputs"].shape, (2, 77)) + self.assertEqual(inputs["text_inputs"].shape, (2, self.processing_tester.num_text, 77)) + + # verify the class labels + self.assertEqual(len(inputs["class_labels"]), 2) + # fmt: off + expected_class_labels = torch.tensor([4, 17, 32, 42, 42, 42, 42, 42, 42, 42, 32, 12, 12, 12, 12, 12, 42, 42, 12, 12, 12, 42, 12, 12, 12, 12, 12, 3, 12, 12, 12, 12, 42, 42, 42, 12, 42, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 5, 12, 12, 12, 12, 12, 12, 12, 0, 43, 43, 43, 96, 43, 104, 43, 31, 125, 31, 125, 138, 87, 125, 149, 138, 125, 87, 87]) # noqa: E231 + # fmt: on + self.assertTrue(torch.allclose(inputs["class_labels"][0], expected_class_labels)) + # fmt: off + expected_class_labels = torch.tensor([19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 67, 82, 19, 19, 17, 19, 19, 19, 19, 19, 19, 19, 19, 19, 12, 12, 42, 12, 12, 12, 12, 3, 14, 12, 12, 12, 12, 12, 12, 12, 12, 14, 5, 12, 12, 0, 115, 43, 43, 115, 43, 43, 43, 8, 8, 8, 138, 138, 125, 143]) # noqa: E231 + # fmt: on + self.assertTrue(torch.allclose(inputs["class_labels"][1], expected_class_labels)) + + # verify the task inputs + self.assertEqual(len(inputs["task_inputs"]), 2) + self.assertEqual(inputs["task_inputs"][0].sum().item(), 136240) + self.assertEqual(inputs["task_inputs"][0].sum().item(), inputs["task_inputs"][1].sum().item()) + + # verify the text inputs + self.assertEqual(len(inputs["text_inputs"]), 2) + self.assertEqual(inputs["text_inputs"][0].sum().item(), 1048653) + self.assertEqual(inputs["text_inputs"][1].sum().item(), 1067160) + + # verify the mask labels + self.assertEqual(len(inputs["mask_labels"]), 2) + self.assertEqual(inputs["mask_labels"][0].shape, (79, 512, 711)) + self.assertEqual(inputs["mask_labels"][1].shape, (61, 512, 711)) + self.assertEqual(inputs["mask_labels"][0].sum().item(), 315193.0) + self.assertEqual(inputs["mask_labels"][1].sum().item(), 350747.0) + + def test_binary_mask_to_rle(self): + fake_binary_mask = np.zeros((20, 50)) + fake_binary_mask[0, 20:] = 1 + fake_binary_mask[1, :15] = 1 + fake_binary_mask[5, :10] = 1 + + rle = binary_mask_to_rle(fake_binary_mask) + self.assertEqual(len(rle), 4) + self.assertEqual(rle[0], 21) + self.assertEqual(rle[1], 45) + + def test_post_process_semantic_segmentation(self): + image_processor = OneFormerImageProcessor( + reduce_labels=True, + ignore_index=0, + size=(512, 512), + class_info_file="ade20k_panoptic.json", + num_text=self.processing_tester.num_text, + ) + tokenizer = CLIPTokenizer.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny") + processor = OneFormerProcessor( + image_processor=image_processor, + tokenizer=tokenizer, + max_seq_length=77, + task_seq_length=77, + ) + + outputs = self.processing_tester.get_fake_oneformer_outputs() + + segmentation = processor.post_process_semantic_segmentation(outputs) + + self.assertEqual(len(segmentation), self.processing_tester.batch_size) + self.assertEqual( + segmentation[0].shape, + ( + self.processing_tester.height, + self.processing_tester.width, + ), + ) + + target_sizes = [(1, 4) for i in range(self.processing_tester.batch_size)] + segmentation = processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes) + + self.assertEqual(segmentation[0].shape, target_sizes[0]) + + def test_post_process_instance_segmentation(self): + image_processor = OneFormerImageProcessor( + reduce_labels=True, + ignore_index=0, + size=(512, 512), + class_info_file="ade20k_panoptic.json", + num_text=self.processing_tester.num_text, + ) + tokenizer = CLIPTokenizer.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny") + processor = OneFormerProcessor( + image_processor=image_processor, + tokenizer=tokenizer, + max_seq_length=77, + task_seq_length=77, + ) + + outputs = self.processing_tester.get_fake_oneformer_outputs() + segmentation = processor.post_process_instance_segmentation(outputs, threshold=0) + + self.assertTrue(len(segmentation) == self.processing_tester.batch_size) + for el in segmentation: + self.assertTrue("segmentation" in el) + self.assertTrue("segments_info" in el) + self.assertEqual(type(el["segments_info"]), list) + self.assertEqual(el["segmentation"].shape, (self.processing_tester.height, self.processing_tester.width)) + + def test_post_process_panoptic_segmentation(self): + image_processor = OneFormerImageProcessor( + reduce_labels=True, + ignore_index=0, + size=(512, 512), + class_info_file="ade20k_panoptic.json", + num_text=self.processing_tester.num_text, + ) + tokenizer = CLIPTokenizer.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny") + processor = OneFormerProcessor( + image_processor=image_processor, + tokenizer=tokenizer, + max_seq_length=77, + task_seq_length=77, + ) + + outputs = self.processing_tester.get_fake_oneformer_outputs() + segmentation = processor.post_process_panoptic_segmentation(outputs, threshold=0) + + self.assertTrue(len(segmentation) == self.processing_tester.batch_size) + for el in segmentation: + self.assertTrue("segmentation" in el) + self.assertTrue("segments_info" in el) + self.assertEqual(type(el["segments_info"]), list) + self.assertEqual(el["segmentation"].shape, (self.processing_tester.height, self.processing_tester.width)) diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index 2e4613d918f5..8f009ab4dcb3 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -124,6 +124,8 @@ src/transformers/models/mobilevit/modeling_tf_mobilevit.py src/transformers/models/nat/configuration_nat.py src/transformers/models/nat/modeling_nat.py src/transformers/models/nezha/configuration_nezha.py +src/transformers/models/oneformer/configuration_oneformer.py +src/transformers/models/oneformer/modeling_oneformer.py src/transformers/models/openai/configuration_openai.py src/transformers/models/opt/configuration_opt.py src/transformers/models/opt/modeling_opt.py