Model Classes
In addition to the core building blocks of the DAT architecture, the dual-attention package implements nn.Module classes for complete DAT models covering several task paradigms. In particular, it provides the following models:
- DualAttnTransformerLM: a decoder-only Dual Attention Transformer for language modeling tasks.
- Seq2SeqDualAttnTransformer: an encoder-decoder Dual Attention Transformer for sequence-to-sequence tasks.
- VisionDualAttnTransformer: a ViT-style Dual Attention Transformer for vision tasks. It takes an image tensor as input and processes it by splitting it into patches.
These models can be instantiated and used as PyTorch modules in the same way as any other PyTorch model.
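To make the patching step of VisionDualAttnTransformer concrete, here is a minimal pure-Python sketch of a generic ViT-style patch split. It is written for illustration only: the function name `patchify` and the channel-first flattening order are assumptions, not the package's implementation, which works on PyTorch tensors. In ViT-style models the flattened patches are then typically projected to `d_model` by a learned linear layer.

```python
from typing import List

Image = List[List[List[float]]]  # indexed as [channel][row][col], i.e. (C, H, W)


def patchify(img: Image, patch_h: int, patch_w: int) -> List[List[float]]:
    """Split a (C, H, W) image into non-overlapping patches, flattening each one.

    Patches are emitted in row-major order over the patch grid; each patch is
    flattened channel-first into a vector of length C * patch_h * patch_w.
    """
    channels, height, width = len(img), len(img[0]), len(img[0][0])
    if height % patch_h or width % patch_w:
        raise ValueError("image size must be divisible by the patch size")
    patches = []
    for top in range(0, height, patch_h):
        for left in range(0, width, patch_w):
            patch = [
                img[c][top + i][left + j]
                for c in range(channels)
                for i in range(patch_h)
                for j in range(patch_w)
            ]
            patches.append(patch)
    return patches


# A toy 3x4x4 image split into 2x2 patches gives a 2x2 grid of 4 patches,
# each a vector of 3 * 2 * 2 = 12 values. Pixel value encodes (channel, row, col).
toy = [[[float(c * 100 + r * 10 + col) for col in range(4)] for r in range(4)] for c in range(3)]
patches = patchify(toy, 2, 2)
print(len(patches), len(patches[0]))  # 4 12
```

The resulting list of patch vectors is the token sequence that a ViT-style model attends over, which is why the number of patches fixes the sequence length.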
Examples
Below, we provide an example of how to instantiate and use the DualAttnTransformerLM model for language modeling tasks:
```python
import torch
from dual_attention.language_models import DualAttnTransformerLM

dat_lm = DualAttnTransformerLM(
    vocab_size=50257,  # vocabulary size
    d_model=1024,  # model dimension
    n_layers=24,  # number of layers
    n_heads_sa=8,  # number of self-attention heads
    n_heads_ra=8,  # number of relational attention heads
    dff=None,  # feedforward intermediate dimension
    dropout_rate=0.,  # dropout rate
    activation='swiglu',  # activation function of feedforward block
    norm_first=True,  # True for pre-norm, False for post-norm
    max_block_size=1024,  # max context length
    symbol_retrieval='symbolic_attention',  # type of symbol assignment mechanism
    symbol_retrieval_kwargs=dict(d_model=1024, n_heads=8, n_symbols=1024),
    pos_enc_type='RoPE',  # type of positional encoding to use
)

idx = torch.randint(0, 50257, (1, 1024+1))  # one random sequence of 1025 token ids
x, y = idx[:, :-1], idx[:, 1:]  # inputs and next-token targets, each of length 1024
logits, loss = dat_lm(x, y)
logits.shape  # (1, 1024, 50257)
```
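The slicing `idx[:, :-1], idx[:, 1:]` sets up next-token prediction: position t of `x` is trained to predict position t of `y`, which is token t+1 of `idx`. The stdlib sketch below mirrors that shift on a plain list and computes a token-averaged cross-entropy. It assumes the `loss` returned by the model is the standard mean cross-entropy over positions, which the example above does not state explicitly; the helper names are hypothetical.

```python
import math
from typing import List, Tuple


def shift_for_next_token(tokens: List[int]) -> Tuple[List[int], List[int]]:
    """Split a sequence of length T+1 into inputs and targets of length T."""
    return tokens[:-1], tokens[1:]


def mean_cross_entropy(logits: List[List[float]], targets: List[int]) -> float:
    """Average of -log softmax(logits[t])[targets[t]] over positions t."""
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # subtract the max before exponentiating for numerical stability
        log_norm = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_norm - row[target]
    return total / len(targets)


idx = [5, 2, 7, 1]  # T + 1 = 4 tokens
x, y = shift_for_next_token(idx)
print(x, y)  # [5, 2, 7] [2, 7, 1]

# With uniform logits over a vocabulary of size V, every position costs log(V).
V = 8
uniform = [[0.0] * V for _ in x]
print(mean_cross_entropy(uniform, y))  # approximately log(8), about 2.079
```

The uniform-logits case is a useful sanity check: an untrained language model with vocabulary size 50257 should start with a loss near log(50257), roughly 10.8.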
Below, we provide an example of how to instantiate and use the VisionDualAttnTransformer model for vision tasks:
```python
import torch
from dual_attention.vision_models import VisionDualAttnTransformer

img_shape = (3, 224, 224)  # (channels, height, width)
patch_size = (16, 16)
n_patches = (img_shape[1] // patch_size[0]) * (img_shape[2] // patch_size[1])

dat_vision = VisionDualAttnTransformer(
    image_shape=img_shape,  # shape of input image
    patch_size=patch_size,  # size of patch
    num_classes=1000,  # number of classes
    d_model=512,  # model dimension
    n_layers=6,  # number of layers
    n_heads_sa=4,  # number of self-attention heads
    n_heads_ra=4,  # number of relational attention heads
    dff=2048,  # feedforward intermediate dimension
    dropout_rate=0.1,  # dropout rate
    activation='swiglu',  # activation function of feedforward block
    norm_first=True,  # True for pre-norm, False for post-norm
    symbol_retrieval='position_relative',  # type of symbol assignment mechanism
    symbol_retrieval_kwargs=dict(symbol_dim=512, max_rel_pos=n_patches+1),
    ra_kwargs=dict(symmetric_rels=True, use_relative_positional_symbols=True),
    pool='cls',  # type of pooling (class token)
)

img = torch.randn(1, *img_shape)  # a batch containing one random image
logits = dat_vision(img)
logits.shape  # (1, 1000)
```
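For the configuration above, the sequence length can be worked out by hand. Assuming a standard ViT-style patch embedding with one class token prepended (as implied by pool='cls'), a 224×224 RGB image split into 16×16 patches gives:

$$
N_{\text{patches}} = \frac{224}{16} \cdot \frac{224}{16} = 14 \cdot 14 = 196,
\qquad
L = N_{\text{patches}} + 1 = 197,
\qquad
d_{\text{patch}} = 3 \cdot 16 \cdot 16 = 768
$$

Here L equals the max_rel_pos=n_patches+1 passed to the position-relative symbol retrieval, presumably so that the relative-position symbols cover the full token sequence including the class token, and d_patch is the flattened size of one patch before it is mapped to d_model=512. Both readings are inferred from the parameter names rather than stated by the package.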
Loading Pre-trained Models from Hugging Face
The dual_attention.hf module provides a convenient interface to the Hugging Face Hub for loading pre-trained DAT models or sharing your own trained models.
To load a pre-trained model from Hugging Face, you can simply run the following code:
```python
from dual_attention.hf import DualAttnTransformerLM_HFHub

model = DualAttnTransformerLM_HFHub.from_pretrained("awni00/DAT-sa16-ra16-nr128-ns2048-sh16-nkvh8-1.27B")
```