Intent Classification: Multi-Label (web_search / diagram_enabled)

Fine-tuned from Falconsai/intent_classification (DistilBERT-base-uncased, apache-2.0) for multi-label binary intent classification. The original 15-class head was replaced with a 2-label sigmoid head trained with BCEWithLogitsLoss.
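To make the loss concrete, here is a minimal pure-Python sketch of what BCEWithLogitsLoss computes with its default mean reduction. Each label gets its own sigmoid and its own binary cross-entropy term, so the two probabilities do not have to sum to 1. The helper names and example logits are illustrative assumptions, not taken from the training code.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic function mapping a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def bce_with_logits(logits, targets) -> float:
    """Mean binary cross-entropy over labels, computed directly from logits.

    Uses the numerically stable form max(z, 0) - z*y + log(1 + exp(-|z|)),
    which equals -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))].
    """
    per_label = [
        max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
        for z, y in zip(logits, targets)
    ]
    return sum(per_label) / len(per_label)

# Hypothetical logits for one query, ordered [web_search, diagram_enabled].
example_logits = [2.0, 1.5]
example_targets = [1.0, 1.0]  # both labels may be active at the same time

probs = [sigmoid(z) for z in example_logits]
loss = bce_with_logits(example_logits, example_targets)
print(probs, loss)
```

Because the labels are scored independently, a query such as "plot today's exchange rates" can legitimately trigger both web_search and diagram_enabled, which a softmax head could not express.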

Labels

| Index | Label | Meaning |
|---|---|---|
| 0 | web_search | Query requires a live web search |
| 1 | diagram_enabled | Query benefits from a diagram / visualisation |

Training details

| Setting | Value |
|---|---|
| Base model | Falconsai/intent_classification (DistilBERT) |
| Problem type | multi_label_classification |
| Frozen layers | embeddings + transformer.layer[0-3] |
| Trainable params | 7M / 67M total (10%) |
| Classifier dropout | 0.3 |
| Learning rate | 5e-6 |
| Early stopping | patience=3 on eval_loss |
| Threshold floor | 0.30 |
| Max sequence length | 128 |
| Split | 80 / 10 / 10 (train / val / test) |
| Seed | 42 |
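The freezing scheme listed above could be applied by parameter name. The sketch below assumes the standard Hugging Face DistilBERT parameter naming (for example `distilbert.transformer.layer.3.attention.q_lin.weight`); the helper `is_frozen` is hypothetical and is not the author's training code.

```python
import re

# Transformer blocks to freeze, per "transformer.layer[0-3]" in the table.
FROZEN_LAYERS = {0, 1, 2, 3}
_LAYER_RE = re.compile(r"transformer\.layer\.(\d+)\.")

def is_frozen(param_name: str) -> bool:
    """Return True if a parameter belongs to the embeddings or layers 0-3."""
    if ".embeddings." in param_name or param_name.startswith("embeddings."):
        return True
    match = _LAYER_RE.search(param_name)
    return bool(match) and int(match.group(1)) in FROZEN_LAYERS

# Intended use with a loaded model (not executed here):
# for name, param in model.named_parameters():
#     param.requires_grad = not is_frozen(name)

sample_names = [
    "distilbert.embeddings.word_embeddings.weight",
    "distilbert.transformer.layer.3.ffn.lin1.weight",
    "distilbert.transformer.layer.4.attention.q_lin.weight",
    "pre_classifier.weight",
    "classifier.bias",
]
for name in sample_names:
    print(name, "frozen" if is_frozen(name) else "trainable")
```

Matching on names rather than on module objects keeps the rule easy to audit: printing the frozen/trainable status for every parameter shows exactly which weights the optimizer will update.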

Decision thresholds

| Label | Threshold |
|---|---|
| web_search | 0.35 |
| diagram_enabled | 0.6 |

Thresholds are stored in thresholds.json and are also embedded in config.json (read in code as model.config.thresholds), so no separate download is needed.
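The card does not say how the per-label thresholds were chosen. One common approach, shown here purely as an assumption, is a grid search on the validation split that maximises F1 for each label separately while never going below the 0.30 floor from the training table. The function names, step size, and toy data are all illustrative.

```python
def f1_score(y_true, y_pred) -> float:
    """Binary F1; returns 0.0 when there are no true positives."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

def tune_threshold(probs, labels, floor: float = 0.30, step: float = 0.05):
    """Grid-search one label's threshold in [floor, 1.0), keeping the lowest best."""
    best_t, best_f1 = floor, -1.0
    n_steps = int(round((1.0 - floor) / step))
    for i in range(n_steps):
        t = round(floor + i * step, 2)
        preds = [int(p >= t) for p in probs]
        score = f1_score(labels, preds)
        if score > best_f1:  # strict ">" keeps the lowest threshold on ties
            best_t, best_f1 = t, score
    return best_t, best_f1

# Toy validation scores for a single label (made-up numbers).
val_probs = [0.10, 0.32, 0.40, 0.55, 0.70, 0.90]
val_labels = [0, 0, 0, 1, 1, 1]
print(tune_threshold(val_probs, val_labels))
```

Tuning each label on its own is consistent with the two labels ending up with different cut-offs (0.35 and 0.6), since each sigmoid output has its own score distribution.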

Usage

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

REPO = "aitraineracc/intent-classification-multilabel"
tokenizer = AutoTokenizer.from_pretrained(REPO)
model     = AutoModelForSequenceClassification.from_pretrained(REPO)
model.eval()

thresholds = model.config.thresholds  # {'web_search': 0.35, 'diagram_enabled': 0.6}

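def predict(text: str) -> dict:
    """Return a 0/1 decision per label plus the raw sigmoid probabilities."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        logits = model(**inputs).logits  # shape: (1, 2)
    # Independent sigmoid per label (multi-label), not a softmax across labels.
    # squeeze(0) drops only the batch dimension, leaving [p_web_search, p_diagram].
    probs = torch.sigmoid(logits).squeeze(0).tolist()
    return {
        "web_search"      : int(probs[0] >= thresholds["web_search"]),
        "diagram_enabled" : int(probs[1] >= thresholds["diagram_enabled"]),
        "probs"           : {"web_search": round(probs[0], 4),
                              "diagram_enabled": round(probs[1], 4)},
    }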

print(predict("What is the weather today in Singapore?"))
# {'web_search': 1, 'diagram_enabled': 0, 'probs': ...}

print(predict("Draw me a diagram of how TCP/IP works"))
# {'web_search': 0, 'diagram_enabled': 1, 'probs': ...}