Intent Classification - Multi-Label (web_search / diagram_enabled)
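Fine-tuned from `Falconsai/intent_classification` (DistilBERT-base-uncased, apache-2.0) for multi-label intent classification with two independent binary labels.
The original 15-class head was replaced with a 2-label sigmoid head trained with `BCEWithLogitsLoss`, so each label gets its own probability and a query can carry both labels, one, or neither.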
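For reference, here is what that loss computes for one query. Each of the two logits gets its own sigmoid, and the loss is the mean binary cross-entropy across the labels. This is a pure-Python illustration only. It is not the training code, which would call `torch.nn.BCEWithLogitsLoss` directly. The card does not say whether a `pos_weight` was used, so the sketch assumes none. It uses the numerically stable form `max(x, 0) - x*y + log(1 + exp(-|x|))`, which for a single example matches that loss with its default `reduction="mean"`.

```python
import math

def sigmoid(x: float) -> float:
    """Numerically stable logistic function (no overflow for large |x|)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def bce_with_logits(logits: list[float], targets: list[float]) -> float:
    """Mean binary cross-entropy over labels, computed from raw logits.

    Per element: max(x, 0) - x * y + log1p(exp(-|x|)), which equals
    -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))].
    """
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    total = 0.0
    for x, y in zip(logits, targets):
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)

# Hypothetical raw outputs of the 2-label head for one query
logits = [1.2, -0.8]
targets = [1.0, 0.0]  # gold: web_search=1, diagram_enabled=0
print([round(sigmoid(x), 3) for x in logits])
print(round(bce_with_logits(logits, targets), 4))
```

Because the probabilities are independent, they do not sum to 1. That is why each label is later compared against its own threshold.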
Labels
| Index | Label | Meaning |
|---|---|---|
| 0 | `web_search` | Query requires a live web search |
| 1 | `diagram_enabled` | Query benefits from a diagram / visualisation |
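With independent labels, each training target is a two-element multi-hot vector in the index order of the table above, not a single class id. A minimal sketch of that encoding follows. The helper names `to_multi_hot` and `from_multi_hot` are hypothetical and are not part of the repository.

```python
# Index order follows the labels table: 0 = web_search, 1 = diagram_enabled.
LABELS = ["web_search", "diagram_enabled"]

def to_multi_hot(active: set[str]) -> list[float]:
    """Encode a set of active label names as a float target vector for BCE."""
    unknown = active - set(LABELS)
    if unknown:
        raise ValueError(f"unknown labels: {sorted(unknown)}")
    return [1.0 if name in active else 0.0 for name in LABELS]

def from_multi_hot(vector: list[float]) -> set[str]:
    """Decode a 0/1 vector back into the set of active label names."""
    return {name for name, v in zip(LABELS, vector) if v >= 0.5}

print(to_multi_hot({"web_search", "diagram_enabled"}))  # a query needing both
print(from_multi_hot([0.0, 1.0]))
```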
Training details
| Setting | Value |
|---|---|
| Base model | Falconsai/intent_classification (DistilBERT) |
| Problem type | multi_label_classification |
| Frozen layers | embeddings + transformer.layer[0-3] |
| Trainable params | |
| Classifier dropout | 0.3 |
| Learning rate | 5e-6 |
| Early stopping | patience=3 on eval_loss |
| Threshold floor | 0.30 |
| Max sequence length | 128 |
| Split | 80 / 10 / 10 (train / val / test) |
| Seed | 42 |
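The card does not describe how the 80 / 10 / 10 split was produced. One plausible reading is a seeded random shuffle followed by contiguous cuts, sketched below in pure Python. Python's `random` module will not reproduce splits made with other libraries, so treat this as an illustration of the ratios and seeding, not a way to recover the actual test set. It also does not stratify by label combination.

```python
import random

def split_80_10_10(examples: list, seed: int = 42) -> tuple[list, list, list]:
    """Shuffle a copy of `examples` with a fixed seed, then cut it 80/10/10."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)  # local RNG; global state untouched
    n = len(shuffled)
    n_train = n * 8 // 10  # integer arithmetic avoids float rounding surprises
    n_val = n // 10
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]  # remainder goes to test
    return train, val, test

# Hypothetical toy data: (text, [web_search, diagram_enabled]) pairs
data = [(f"query {i}", [float(i % 2), float(i % 3 == 0)]) for i in range(50)]
train, val, test = split_80_10_10(data)
print(len(train), len(val), len(test))  # 40 5 5
```

If the model was trained with `transformers`, the early-stopping row would typically correspond to `EarlyStoppingCallback(early_stopping_patience=3)` together with `load_best_model_at_end=True` and `metric_for_best_model="eval_loss"` in `TrainingArguments`. That is also an assumption about the training setup, not something the card states.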
Decision thresholds
| Label | Threshold |
|---|---|
| `web_search` | 0.35 |
| `diagram_enabled` | 0.6 |
Thresholds are stored in `thresholds.json` and embedded in `config.json` under `config.thresholds`, so no separate download is needed.
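The card gives the final thresholds and a 0.30 floor but not the tuning procedure. One common approach, assumed here, is a per-label grid search on the validation split that maximises F1 and never goes below the floor. The sketch below is pure Python, and `tune_threshold` and `apply_thresholds` are hypothetical helper names. It also shows how the stored per-label values turn probabilities into 0/1 decisions.

```python
THRESHOLDS = {"web_search": 0.35, "diagram_enabled": 0.6}  # values from the table above

def tune_threshold(probs: list[float], gold: list[int],
                   floor: float = 0.30, step: float = 0.05) -> float:
    """Grid-search a threshold in [floor, 1) that maximises F1 for one label.

    Hypothetical procedure: the card only states the floor and final values.
    Ties keep the lowest threshold found.
    """
    best_t, best_f1 = floor, -1.0
    n_steps = int(round((1.0 - floor) / step))
    for i in range(n_steps):
        t = round(floor + i * step, 4)  # round away float drift (e.g. 0.45000000000000007)
        tp = sum(1 for p, g in zip(probs, gold) if p >= t and g == 1)
        fp = sum(1 for p, g in zip(probs, gold) if p >= t and g == 0)
        fn = sum(1 for p, g in zip(probs, gold) if p < t and g == 1)
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        if f1 > best_f1:
            best_t, best_f1 = t, f1
    return best_t

def apply_thresholds(probs: dict[str, float],
                     thresholds: dict[str, float] = THRESHOLDS) -> dict[str, int]:
    """Turn per-label probabilities into 0/1 decisions, one threshold per label."""
    return {name: int(probs[name] >= t) for name, t in thresholds.items()}

# Toy validation scores for one label (hypothetical)
print(tune_threshold([0.1, 0.4, 0.7, 0.9], [0, 0, 1, 1]))
# Same probability, different decisions, because the thresholds differ
print(apply_thresholds({"web_search": 0.5, "diagram_enabled": 0.5}))
```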
Usage
```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

REPO = "aitraineracc/intent-classification-multilabel"

tokenizer = AutoTokenizer.from_pretrained(REPO)
model = AutoModelForSequenceClassification.from_pretrained(REPO)
model.eval()

# Per-label decision thresholds, read from config.json
thresholds = model.config.thresholds  # {'web_search': 0.35, 'diagram_enabled': 0.6}

def predict(text: str) -> dict:
    # Truncate to the same 128-token limit used in training
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        logits = model(**inputs).logits  # shape (1, 2)
    # Independent sigmoid per label; squeeze() drops the batch dim (single input only)
    probs = torch.sigmoid(logits).squeeze().tolist()
    return {
        "web_search": int(probs[0] >= thresholds["web_search"]),
        "diagram_enabled": int(probs[1] >= thresholds["diagram_enabled"]),
        "probs": {"web_search": round(probs[0], 4),
                  "diagram_enabled": round(probs[1], 4)},
    }

print(predict("What is the weather today in Singapore?"))
# {'web_search': 1, 'diagram_enabled': 0, 'probs': ...}
print(predict("Draw me a diagram of how TCP/IP works"))
# {'web_search': 0, 'diagram_enabled': 1, 'probs': ...}
```