Title: On the Role of Attention in Prompt-tuning
Prompt-tuning is an emerging strategy to adapt large language models (LLMs) to downstream tasks by learning a (soft-)prompt parameter from data. Despite its success in LLMs, there is limited theoretical understanding of the power of prompt-tuning and of the role of the attention mechanism in prompting. In this work, we explore prompt-tuning for one-layer attention architectures and study contextual mixture models in which each input token belongs to either a context-relevant or a context-irrelevant set. We isolate the role of prompt-tuning through a self-contained prompt-attention model. Our contributions are as follows: (1) We show that softmax-prompt-attention is provably more expressive than softmax-self-attention and linear-prompt-attention under our contextual data model. (2) We analyze the initial trajectory of gradient descent, show that it learns the prompt and prediction head with near-optimal sample complexity, and demonstrate how the prompt can provably attend to sparse context-relevant tokens. (3) Assuming a known prompt but an unknown prediction head, we characterize the exact finite-sample performance of prompt-attention, which reveals the fundamental performance limits and the precise benefit of the context information. We also provide experiments that verify our theoretical insights on real datasets and demonstrate how prompt-tuning enables the model to attend to context-relevant information.
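To make the idea of a prompt that "attends to context-relevant tokens" concrete, here is a minimal sketch, not the paper's exact model. It assumes a simplified single-prompt form f(X) = v^T X^T softmax(X p), in which a prompt vector p scores each token directly (any key or value matrices in the paper's parameterization are omitted here) and v is a linear prediction head. The function names and the toy tokens are hypothetical and chosen only for illustration.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[Vector]  # each row is one input token x_t in R^d


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax(scores: Vector) -> Vector:
    m = max(scores)  # subtract the max score for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def prompt_attention_weights(X: Matrix, p: Vector) -> Vector:
    # Assumed simplification: the prompt p is the only query, so each token's
    # score is <x_t, p> and the weights are a softmax over tokens.
    return softmax([dot(x, p) for x in X])


def prompt_attention(X: Matrix, p: Vector, v: Vector) -> float:
    # f(X) = v^T X^T softmax(X p): attention-weighted average of the tokens,
    # followed by a linear prediction head v.
    w = prompt_attention_weights(X, p)
    d = len(v)
    pooled = [sum(w[t] * X[t][j] for t in range(len(X))) for j in range(d)]
    return dot(v, pooled)


if __name__ == "__main__":
    # Hypothetical toy input: token 0 is "context-relevant" (feature 1.0 in
    # dimension 0, label signal +1 in dimension 1); the rest are irrelevant
    # tokens carrying a misleading -1 in dimension 1.
    X = [[1.0, 1.0], [0.0, -1.0], [0.0, -1.0], [0.0, -1.0]]
    head = [0.0, 1.0]
    print(prompt_attention(X, [0.0, 0.0], head))   # uniform attention
    print(prompt_attention(X, [10.0, 0.0], head))  # prompt aligned with context
```

With a zero prompt the attention is uniform, so the irrelevant tokens dominate the average and the prediction is -0.5. A prompt aligned with the context feature concentrates almost all of the softmax weight on the relevant token, and the output moves close to +1. This is only an intuition pump for the abstract's claim; the paper's theory concerns learning p and v from data by gradient descent.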
Award ID(s):
1846369 1813877
NSF-PAR ID:
10483606
Publisher / Repository:
Proceedings of the 40th International Conference on Machine Learning
Location:
Honolulu, USA
Sponsoring Org:
National Science Foundation
More Like this
  1. Multivariate time-series data are frequently observed in critical care settings and are typically characterized by sparsity (missing information) and irregular time intervals. Existing approaches for learning representations in this domain handle these challenges by either aggregation or imputation of values, which in turn suppresses the fine-grained information and adds undesirable noise/overhead into the machine learning model. To tackle this problem, we propose a Self-supervised Transformer for Time-Series (STraTS) model, which overcomes these pitfalls by treating time-series as a set of observation triplets instead of using the standard dense matrix representation. It employs a novel Continuous Value Embedding technique to encode continuous time and variable values without the need for discretization. It is composed of a Transformer component with multi-head attention layers, which enable it to learn contextual triplet embeddings while avoiding the problems of recurrence and vanishing gradients that occur in recurrent architectures. In addition, to tackle the problem of limited availability of labeled data (which is typically observed in many healthcare applications), STraTS utilizes self-supervision by leveraging unlabeled data to learn better representations by using time-series forecasting as an auxiliary proxy task. Experiments on real-world multivariate clinical time-series benchmark datasets demonstrate that STraTS has better prediction performance than state-of-the-art methods for mortality prediction, especially when labeled data is limited. Finally, we also present an interpretable version of STraTS, which can identify important measurements in the time-series data. Our data preprocessing and model implementation code is available at https://github.com/sindhura97/STraTS .
  2. Identifying instances when a user will not be able to attend to an incoming message and constructing an auto-response with relevant contextual information may help reduce the social pressure to respond immediately that many users face. Mobile messaging behavior often varies from one person to another. As a result, compared to a generic model considering profiles of several users, a personalized model can capture a user's messaging behavior more accurately to predict their inattentive states. However, creating accurate personalized models requires a non-trivial amount of individual data, which is often not available for new users. In this work, we investigate a weighted hybrid approach to model users' attention to messaging. Through dynamic performance-based weighting, we combine the predictions of three types of models (a general model, a group model, and a personalized model) to create an approach that can work through the lack of initial data while adapting to the user's behavior. We present the details of our modeling approach and the evaluation of the model with over three weeks of data from 274 users. Our results highlight the value of hybrid weighted modeling to predict when a user cannot attend to their messages.
  3. Context is of fundamental importance to both human and machine vision; e.g., an object in the air is more likely to be an airplane than a pig. The rich notion of context incorporates several aspects including physics rules, statistical co-occurrences, and relative object sizes, among others. While previous work has focused on crowd-sourced out-of-context photographs from the web to study scene context, controlling the nature and extent of contextual violations has been a daunting task. Here we introduce a diverse, synthetic Out-of-Context Dataset (OCD) with fine-grained control over scene context. By leveraging a 3D simulation engine, we systematically control the gravity, object co-occurrences and relative sizes across 36 object categories in a virtual household environment. We conducted a series of experiments to gain insights into the impact of contextual cues on both human and machine vision using OCD. We conducted psychophysics experiments to establish a human benchmark for out-of-context recognition, and then compared it with state-of-the-art computer vision models to quantify the gap between the two. We propose a context-aware recognition transformer model, fusing object and contextual information via multi-head attention. Our model captures useful information for contextual reasoning, enabling human-level performance and better robustness in out-of-context conditions compared to baseline models across OCD and other out-of-context datasets. All source code and data are publicly available at https://github.com/kreimanlab/WhenPigsFlyContext
  4. To accurately categorize items, humans learn to selectively attend to the stimulus dimensions that are most relevant to the task. Models of category learning describe how attention changes across trials as labeled stimuli are progressively observed. The Adaptive Attention Representation Model (AARM), for example, provides an account in which categorization decisions are based on the perceptual similarity of a new stimulus to stored exemplars, and dimension-wise attention is updated on every trial in the direction of a feedback-based error gradient. As such, attention modulation as described by AARM requires interactions among processes of orienting, visual perception, memory retrieval, prediction error, and goal maintenance to facilitate learning. The current study explored the neural bases of attention mechanisms using quantitative predictions from AARM to analyze behavioral and fMRI data collected while participants learned novel categories. Generalized linear model analyses revealed patterns of BOLD activation in the parietal cortex (orienting), visual cortex (perception), medial temporal lobe (memory retrieval), basal ganglia (prediction error), and pFC (goal maintenance) that covaried with the magnitude of model-predicted attentional tuning. Results are consistent with AARM's specification of attention modulation as a dynamic property of distributed cognitive systems.
  5. Many theories assume that a sensory neuron's higher firing rate indicates a greater probability of its preferred stimulus. However, this contradicts 1) the adaptation phenomena where prolonged exposure to, and thus increased probability of, a stimulus reduces the firing rates of cells tuned to the stimulus; and 2) the observation that unexpected (low probability) stimuli capture attention and increase neuronal firing. Other theories posit that the brain builds predictive/efficient codes for reconstructing sensory inputs. However, they cannot explain why the brain preserves some information while discarding the rest. We propose that in sensory areas, projection neurons' firing rates are proportional to optimal code length (i.e., negative log estimated probability), and their spike patterns are the code, for useful features in inputs. This hypothesis explains adaptation-induced changes of V1 orientation tuning curves, and bottom-up attention. We discuss how the modern minimum-description-length (MDL) principle may help understand neural codes. Because regularity extraction is relative to a model class (defined by cells) via its optimal universal code (OUC), MDL matches the brain's purposeful, hierarchical processing without input reconstruction. Such processing enables input compression/understanding even when model classes do not contain true models. Top-down attention modifies lower-level OUCs via feedback connections to enhance transmission of behaviorally relevant information. Although OUCs concern lossless data compression, we suggest possible extensions to lossy, prefix-free neural codes for prompt, online processing of the most important aspects of stimuli while minimizing behaviorally relevant distortion. Finally, we discuss how neural networks might learn MDL's normalized maximum likelihood (NML) distributions from input data.