QRRanker: Query-focused and Memory-aware Reranker for Long Context Processing

🌐 Project Page | 📄 Paper | 🤗 Models

QRRanker is a lightweight reranking framework that leverages Query-focused Retrieval (QR) heads to produce continuous relevance scores, enabling effective listwise reranking with small-scale models.

Model Description

Built upon the existing analysis of retrieval heads in large language models, QRRanker trains models to estimate passage–query relevance using the attention scores of selected Query-focused Retrieval (QR) heads. These heads are identified through QR score computation on seed data and are particularly effective at capturing query-document relevance signals.

Our approach provides a listwise solution that leverages the holistic information within the entire candidate shortlist during ranking. It naturally produces continuous relevance scores, enabling training on arbitrary retrieval datasets without requiring Likert-scale supervision.

Key Features

  • Listwise Reranking: Leverages holistic information within the entire candidate shortlist during ranking
  • Continuous Relevance Scores: Enables training on arbitrary retrieval datasets without requiring Likert-scale supervision
  • Selective Head Usage: Focuses on top-performing QR attention heads
  • Layer Truncation: Only the first 25 of 36 layers are retained; all QR heads fall within layers 17–24, so deeper layers are unnecessary
  • Memory Enhancement: Optional contextual summaries for improved accuracy on long narratives and dialogues

Architecture

This model is a layer-truncated version of Qwen3-4B-Instruct-2507. The original model has 36 transformer layers, of which only the first 25 are retained. The top-performing QR heads (layers 17–24) all fall within this range; the deeper layers contain no QR heads, so they would only consume extra computation and memory.

Key design choices in modeling_qwen3_qr.py:

  • Qwen3ConfigGating: Extends Qwen3Config with qr_start_layer, qr_end_layer, qr_head_list, and qr_head_list_mapped (the QR heads with layer indices remapped relative to qr_start_layer)
  • Layer construction: Only instantiates qr_end_layer (25) layers instead of all num_hidden_layers (36)
  • No final norm: Skips self.norm(hidden_states) since we only need intermediate KV/query caches, not the final hidden state
  • DynamicCacheWithQuery: Custom KV-cache that additionally stores query states at specified token positions during the forward pass
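The idea behind DynamicCacheWithQuery can be illustrated with a toy, framework-free stand-in. This is a hypothetical sketch: the class name ToyCacheWithQuery and its update method are made up here and do not reproduce the interface of the actual custom cache. Alongside the usual per-layer key list, it keeps only the query vectors found at the requested token positions, which is all that QR scoring needs.

```python
class ToyCacheWithQuery:
    """Hypothetical stand-in for DynamicCacheWithQuery (not the real API).

    Per layer it stores every key vector plus the query vectors at the
    requested token positions. Assumes a single prefill pass, so positions
    are absolute indices into the input sequence.
    """

    def __init__(self, query_indices):
        self.query_indices = list(query_indices)
        self.key_cache = []    # key_cache[layer] -> list of key vectors
        self.query_cache = []  # query_cache[layer] -> selected query vectors

    def update(self, layer_idx, keys, queries):
        """Record one layer's keys and the queries at query_indices."""
        # Layers are assumed to arrive in order 0, 1, 2, ...
        if layer_idx == len(self.key_cache):
            self.key_cache.append([])
            self.query_cache.append([])
        self.key_cache[layer_idx].extend(keys)
        self.query_cache[layer_idx].extend(queries[i] for i in self.query_indices)


# Usage: 5 tokens, query tokens at positions 3 and 4, 2-dim vectors.
cache = ToyCacheWithQuery(query_indices=[3, 4])
keys = [[float(t), 0.0] for t in range(5)]
queries = [[0.0, float(t)] for t in range(5)]
cache.update(0, keys, queries)
print(len(cache.key_cache[0]), cache.query_cache[0])  # 5 [[0.0, 3.0], [0.0, 4.0]]
```

Keeping only the question-token queries means the extra memory grows with the query length rather than with the full (potentially very long) context.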

Default Top-16 QR Heads

Layer-Head: 20-15, 21-11, 17-27, 23-10, 22-4, 21-10, 21-8, 21-18, 18-15, 18-19, 17-25, 17-17, 24-13, 17-4, 19-12, 21-31

All selected heads fall within layers 17–24, which is why truncation to 25 layers is safe.
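The check behind this claim is simple arithmetic over the head list. The sketch below (parse_heads is a hypothetical helper, not part of the repository) parses the default list, confirms that the layers span 17 to 24 so an exclusive end bound of 25 suffices, and derives remapped pairs, assuming that "remapped relative to qr_start_layer" means subtracting qr_start_layer from each layer index.

```python
def parse_heads(spec):
    """Parse a 'layer-head' list such as '20-15, 21-11' into (layer, head) tuples."""
    return [tuple(int(x) for x in item.strip().split("-")) for item in spec.split(",")]


DEFAULT_HEADS = ("20-15, 21-11, 17-27, 23-10, 22-4, 21-10, 21-8, 21-18, "
                 "18-15, 18-19, 17-25, 17-17, 24-13, 17-4, 19-12, 21-31")

heads = parse_heads(DEFAULT_HEADS)
layers = [layer for layer, _ in heads]
qr_start_layer = min(layers)    # first layer that contains a QR head
qr_end_layer = max(layers) + 1  # exclusive bound: layers 0..qr_end_layer-1 are kept

# Assumed remapping: layer index relative to qr_start_layer.
mapped = [(layer - qr_start_layer, head) for layer, head in heads]

print(len(heads), qr_start_layer, qr_end_layer)  # 16 17 25
print(mapped[:3])  # [(3, 15), (4, 11), (0, 27)]
```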

Model Configuration

| Parameter | Value | Description |
|---|---|---|
| qr_start_layer | 17 | First layer containing QR heads |
| qr_end_layer | 25 | Layers 0–24 are retained; layers 25–35 are removed |
| qr_head_list | 16 (layer, head) pairs | Top QR heads, using original layer indices |
| qr_head_list_mapped | 16 (layer, head) pairs | QR heads with layer indices remapped relative to qr_start_layer |
| num_hidden_layers | 36 | Original full model depth (config only, not instantiated) |
| num_attention_heads | 32 | Attention heads per layer |
| num_key_value_heads | 8 | GQA key-value heads per layer |
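With 32 query heads and 8 key-value heads, each KV head is shared by 32 / 8 = 4 query heads. The sketch below (kv_head_for and repeat_kv_list are hypothetical helpers) shows which KV head a given QR head reads, mirroring the head ordering that repeat_kv produces in the scoring code: each KV head is repeated n_rep times consecutively.

```python
NUM_ATTENTION_HEADS = 32  # from the configuration table
NUM_KEY_VALUE_HEADS = 8


def kv_head_for(query_head, num_heads=NUM_ATTENTION_HEADS, num_kv_heads=NUM_KEY_VALUE_HEADS):
    """Return the KV head whose keys a given query head attends over under GQA."""
    n_rep = num_heads // num_kv_heads  # 4 query heads share each KV head
    return query_head // n_rep


def repeat_kv_list(kv_heads, n_rep):
    """Pure-Python analogue of repeat_kv: each KV head is repeated n_rep times in place."""
    return [h for h in kv_heads for _ in range(n_rep)]


expanded = repeat_kv_list(list(range(NUM_KEY_VALUE_HEADS)), NUM_ATTENTION_HEADS // NUM_KEY_VALUE_HEADS)
assert all(expanded[q] == kv_head_for(q) for q in range(NUM_ATTENTION_HEADS))
print(kv_head_for(27), kv_head_for(15))  # 6 3
```

So, for example, QR head 17-27 scores with the keys of KV head 6 in layer 17.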

Quick Start

Loading the Model

import torch
from transformers import AutoModel, AutoConfig, AutoTokenizer

# Load the model: trust_remote_code loads the layer-truncated Qwen3Model
# and Qwen3ConfigGating automatically via auto_map in config.json
config = AutoConfig.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "MindscapeRAG/QRRanker",
    config=config,
    torch_dtype=torch.float16,
    trust_remote_code=True,
).cuda().eval()

tokenizer = AutoTokenizer.from_pretrained("MindscapeRAG/QRRanker")

QR Score Computation

After a forward pass, QR scores are computed from the cached query and key states:

import math

import torch


def repeat_kv(hidden_states, n_rep):
    """Expand KV heads to match the number of query heads (GQA)."""
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)


def get_attn_weights(key_states, query_states):
    """Compute softmax attention weights under a causal mask.

    The q_len query positions are assumed to be the last q_len positions
    of the key sequence.
    """
    bsz, num_heads, q_len, head_dim = query_states.size()
    num_kv_heads = key_states.size(1)
    key_states = repeat_kv(key_states, num_heads // num_kv_heads)

    scale = 1.0 / math.sqrt(head_dim)
    attn_weights = torch.matmul(query_states * scale, key_states.transpose(2, 3))

    # Causal mask: query i may attend to key j only if j <= i + (seq_len - q_len)
    seq_len = attn_weights.size(-1)
    causal_mask = torch.ones(num_heads, q_len, seq_len, device=attn_weights.device)
    causal_mask = torch.triu(causal_mask.transpose(-1, -2), diagonal=-(seq_len - q_len)).transpose(-1, -2)
    attn_weights += ((1 - causal_mask) * torch.finfo(attn_weights.dtype).min).unsqueeze(0)

    # Numerically stable softmax over the key dimension
    attn_lses = torch.logsumexp(attn_weights, dim=-1, keepdim=True)
    return torch.exp(attn_weights - attn_lses)


def compute_qr_scores(query_cache, key_cache, qr_head_list, chunk_ranges, query_upper_bound):
    """
    Compute QRRanker relevance scores for document chunks.

    Args:
        query_cache (List[Tensor]): query states per layer from DynamicCacheWithQuery
        key_cache (List[Tensor]): key states per layer
        qr_head_list (str or None): comma-separated "layer-head" pairs, e.g.
            "20-15,21-11,17-27,...". Each layer index selects a position in
            key_cache/query_cache, so it must follow the cache's layer numbering.
            None uses every head.
        chunk_ranges (List[[start, end]]): token range of each chunk
        query_upper_bound (int): one past the last query token position

    Returns:
        scores: Tensor of shape [num_chunks] (batch size 1 assumed)
    """
    all_head_scores = []
    for key_state, query_state in zip(key_cache, query_cache):
        attn_weights = get_attn_weights(key_state[:, :, :query_upper_bound, :], query_state)
        attn_weights = attn_weights.mean(dim=-2)  # average over query positions
        chunk_scores = torch.stack(
            [attn_weights[:, :, s:e].sum(dim=-1) for s, e in chunk_ranges], dim=2
        )  # [batch, num_heads, num_chunks]: attention mass inside each chunk
        all_head_scores.append(chunk_scores)

    # [batch, num_layers, num_heads, num_chunks]
    all_head_scores = torch.stack(all_head_scores, dim=1).float()

    if qr_head_list is not None:
        # Keep only the selected QR heads -> [batch, num_selected, num_chunks]
        head_set = [tuple(map(int, h.split('-'))) for h in qr_head_list.split(',')]
        indices = torch.tensor(head_set, device=all_head_scores.device)
        all_head_scores = all_head_scores[:, indices[:, 0], indices[:, 1], :]
    else:
        # Use every head -> [batch, num_layers * num_heads, num_chunks]
        all_head_scores = all_head_scores.flatten(1, 2)

    # Sum over heads -> [num_chunks]
    return all_head_scores.sum(dim=1).squeeze(0)
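To make the reduction steps concrete, here is a pure-Python toy that mirrors what compute_qr_scores does once attention weights are available: average each head's rows over query positions, sum the weight that falls inside each chunk, and add up the selected heads. The function toy_qr_scores and the numbers are illustrative assumptions, not the repository's code or real model outputs.

```python
def toy_qr_scores(attn, qr_heads, chunk_ranges):
    """Pure-Python mirror of compute_qr_scores' reduction steps.

    attn[layer][head][q][t]: attention weight from query position q to token t
    (rows assumed already softmax-normalised and causally masked).
    qr_heads: list of (layer, head) pairs indexing into attn.
    chunk_ranges: list of [start, end) token ranges, one per chunk.
    """
    scores = [0.0] * len(chunk_ranges)
    for layer, head in qr_heads:
        rows = attn[layer][head]
        # Average over query positions -> one weight per key token.
        mean_row = [sum(col) / len(rows) for col in zip(*rows)]
        # Attention mass inside each chunk, accumulated over the selected heads.
        for c, (s, e) in enumerate(chunk_ranges):
            scores[c] += sum(mean_row[s:e])
    return scores


# One layer, one head; the query tokens are positions 4 and 5, so the
# row for position 4 puts no weight on token 5. Chunks are tokens 0-1 and 2-3.
attn = [[[
    [0.1, 0.1, 0.5, 0.1, 0.2, 0.0],
    [0.1, 0.1, 0.3, 0.3, 0.1, 0.1],
]]]
scores = toy_qr_scores(attn, [(0, 0)], [[0, 2], [2, 4]])
ranking = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
print(scores, ranking)
```

Chunk 1 (tokens 2–3) receives about 0.6 of the averaged attention versus 0.2 for chunk 0, so it ranks first. Because the score is a sum of attention mass rather than a discrete label, it is continuous by construction.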

Complete Inference Pipeline

import torch
from custom_cache_new import DynamicCacheWithQuery


def rerank_documents(model, tokenizer, question, paragraphs, qr_head_list, device):
    """
    Rerank candidate paragraphs by QRRanker relevance scores.

    Args:
        model: QRRanker model (loaded with trust_remote_code=True)
        tokenizer: Corresponding tokenizer
        question: Query string
        paragraphs: List of dicts with 'idx', 'title', 'paragraph_text'
        qr_head_list: str, e.g. "20-15,21-11,17-27,..."
        device: torch device

    Returns:
        ranked_ids: Paragraph indices sorted by descending relevance
        ranked_scores: Corresponding scores
    """
    # Build the input: all chunks first, then the query
    prompt_prefix = '<|im_start|>user\nHere are some retrieved chunks:\n\n'
    chunk_part = prompt_prefix
    chunk_ranges = []  # character span of each chunk in the full sequence

    for i, p in enumerate(paragraphs):
        text = p.get('title', '') + ': ' + p['paragraph_text']
        chunk_part += f"[{i+1}]"
        start = len(chunk_part)
        chunk_part += ' ' + text.strip()
        end = len(chunk_part)
        chunk_ranges.append([start, end])
        chunk_part += '\n\n'

    query_part = f"Use the retrieved chunks to answer the user's query.\n\nQuery: {question}"
    full_seq = chunk_part + query_part

    # Tokenize, keeping the character offsets of every token
    inputs = tokenizer(full_seq, max_length=262144, truncation=True,
                       return_tensors='pt', return_offsets_mapping=True, add_special_tokens=False)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    offset_mapping = inputs['offset_mapping'][0].tolist()

    # Character-to-token mapping
    char_to_token = {}
    for i, (s, e) in enumerate(offset_mapping):
        for j in range(s, e):
            char_to_token[j] = i

    token_chunk_ranges = [
        [char_to_token.get(s, 0), char_to_token.get(e - 1, 0) + 1]
        for s, e in chunk_ranges
    ]

    # The question sits at the very end of full_seq; locating it from the end
    # avoids matching an identical string inside one of the chunks
    query_start = len(full_seq) - len(question)
    query_positions = list(range(
        char_to_token[query_start],
        char_to_token[query_start + len(question) - 1] + 1
    ))
    query_upper_bound = query_positions[-1] + 1

    # Forward pass; the cache records query states at the question tokens
    with torch.no_grad():
        past_kv = DynamicCacheWithQuery(query_indices=query_positions)
        output = model(input_ids, attention_mask, past_key_values=past_kv)
        scores = compute_qr_scores(
            output.past_key_values.query_cache,
            output.past_key_values.key_cache,
            qr_head_list, token_chunk_ranges, query_upper_bound
        )

    sorted_idx = torch.argsort(scores, descending=True).cpu().tolist()
    return [paragraphs[i]['idx'] for i in sorted_idx], [float(scores[i]) for i in sorted_idx]
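The step that is easiest to get wrong in this pipeline is turning character spans into token spans. The sketch below isolates it with hand-written offsets; the text, the tokenisation and the helper name char_spans_to_token_spans are all hypothetical. Unlike the pipeline's char_to_token.get(..., 0), it indexes the mapping directly, so a span that was cut off by truncation raises KeyError instead of silently falling back to token 0.

```python
def char_spans_to_token_spans(offsets, char_spans):
    """Map half-open character spans to half-open token spans.

    offsets: (start, end) character offsets per token, in the format a fast
    tokenizer returns with return_offsets_mapping=True.
    """
    char_to_token = {}
    for tok, (s, e) in enumerate(offsets):
        for ch in range(s, e):
            char_to_token[ch] = tok
    # From the token holding the first character to the one holding the last.
    return [[char_to_token[s], char_to_token[e - 1] + 1] for s, e in char_spans]


question = "capital?"
text = "[1] Paris is big\n\nQuery: " + question
# Hypothetical tokenisation: "[1]", " Paris", " is", " big", "\n\n", "Query:", " capital?"
offsets = [(0, 3), (3, 9), (9, 12), (12, 16), (16, 18), (18, 24), (24, 33)]

chunk_span = (3, 16)                        # " Paris is big", right after the "[1]" marker
query_start = len(text) - len(question)     # the question ends the sequence
spans = char_spans_to_token_spans(offsets, [chunk_span, (query_start, len(text))])
print(spans)  # [[1, 4], [6, 7]]
```

The chunk maps to tokens 1–3 and the question to token 6; note that the leading space of " capital?" makes the whole token count as a query token.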

Input Data Format

{
  "id": "sample_001",
  "question": "What is the capital of France?",
  "answer": "Paris",
  "paragraphs": [
    {
      "idx": 0,
      "title": "France",
      "paragraph_text": "Paris is the capital and largest city of France...",
      "is_supporting": true
    }
  ],
  "summary": "Optional summary text..."
}

| Field | Type | Required | Description |
|---|---|---|---|
| id | string | Yes | Unique sample identifier |
| question | string | Yes | User query/question |
| answer | string | No | Ground-truth answer (for evaluation) |
| paragraphs | list | Yes | List of candidate paragraphs |
| paragraphs[].idx | int | Yes | Paragraph index |
| paragraphs[].title | string | No | Paragraph title |
| paragraphs[].paragraph_text | string | Yes | Paragraph content |
| paragraphs[].is_supporting | bool | No | Whether the paragraph is a supporting paragraph (for evaluation) |
| summary | string | No | Optional summary information |
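Before running the reranker over a dataset, it can help to check each sample against the field table above. The following is a minimal sketch of such a check; validate_sample and its helpers are hypothetical and not part of the repository, and they only test presence and types as listed in the table.

```python
REQUIRED_TOP = {"id": str, "question": str, "paragraphs": list}
OPTIONAL_TOP = {"answer": str, "summary": str}
REQUIRED_PARA = {"idx": int, "paragraph_text": str}
OPTIONAL_PARA = {"title": str, "is_supporting": bool}


def _is_type(value, expected):
    # bool is a subclass of int in Python, so reject True/False where an int is expected.
    if expected is int and isinstance(value, bool):
        return False
    return isinstance(value, expected)


def _check_fields(obj, required, optional, where, problems):
    for key, expected in required.items():
        if key not in obj:
            problems.append(f"{where}: missing required field '{key}'")
        elif not _is_type(obj[key], expected):
            problems.append(f"{where}: '{key}' should be {expected.__name__}")
    for key, expected in optional.items():
        if key in obj and not _is_type(obj[key], expected):
            problems.append(f"{where}: '{key}' should be {expected.__name__}")


def validate_sample(sample):
    """Return a list of schema problems for one sample (empty if it looks valid)."""
    problems = []
    _check_fields(sample, REQUIRED_TOP, OPTIONAL_TOP, "sample", problems)
    paragraphs = sample.get("paragraphs")
    if isinstance(paragraphs, list):
        for i, p in enumerate(paragraphs):
            if not isinstance(p, dict):
                problems.append(f"paragraphs[{i}]: expected an object")
                continue
            _check_fields(p, REQUIRED_PARA, OPTIONAL_PARA, f"paragraphs[{i}]", problems)
    return problems


good = {
    "id": "sample_001",
    "question": "What is the capital of France?",
    "paragraphs": [{"idx": 0, "title": "France",
                    "paragraph_text": "Paris is the capital and largest city of France...",
                    "is_supporting": True}],
}
bad = {"id": "sample_002", "paragraphs": [{"idx": "0"}]}
print(validate_sample(good))  # []
for problem in validate_sample(bad):
    print(problem)
```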

Environment

| Package | Version |
|---|---|
| Python | 3.10 |
| torch | 2.7.1 |
| transformers | 4.53.0 |
| flash-attn | not pinned (required for flash_attention_2) |
| safetensors | 0.5.3 |
| tokenizers | 0.21.2 |

pip install torch==2.7.1 transformers==4.53.0 safetensors
pip install flash-attn --no-build-isolation
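A quick way to confirm that an environment matches the pinned versions is to query the installed distributions with the standard library. This is a small sketch under the assumption that the pins in the table above are exact requirements; version_mismatches is a hypothetical helper, and local build suffixes such as "+cu126" on torch wheels are ignored when comparing.

```python
from importlib import metadata

PINNED = {"torch": "2.7.1", "transformers": "4.53.0",
          "safetensors": "0.5.3", "tokenizers": "0.21.2"}


def version_mismatches(pinned, lookup=metadata.version):
    """Return {package: installed_version_or_None} for missing or differing packages."""
    out = {}
    for name, want in pinned.items():
        try:
            have = lookup(name)
        except metadata.PackageNotFoundError:
            have = None
        # Strip a local build suffix such as "+cu126" before comparing.
        if have is None or have.split("+", 1)[0] != want:
            out[name] = have
    return out


if __name__ == "__main__":
    print(version_mismatches(PINNED) or "all pinned versions match")
```

The lookup parameter exists only so the check can be exercised with a fake registry; in normal use the default importlib.metadata.version is used.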

Citation

@misc{li2026queryfocusedmemoryawarererankerlong,
  title={Query-focused and Memory-aware Reranker for Long Context Processing},
  author={Yuqing Li and Jiangnan Li and Mo Yu and Guoxuan Ding and Zheng Lin and Weiping Wang and Jie Zhou},
  year={2026},
  eprint={2602.12192},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2602.12192},
}

License

This project is licensed under the Apache 2.0 License.

Safetensors weights: 3B params, tensor type BF16