Files
classement-image-cyberharce…/tweet_classifier/classifier.py
T
2026-06-28 20:21:40 +02:00

52 lines
2.3 KiB
Python

import logging
from typing import List, Dict, Any, Optional
from transformers import pipeline
logger = logging.getLogger(__name__)
class ZeroShotClassifier:
"""
Classe responsable de la classification de texte à l'aide de modèles Hugging Face Zero-Shot.
"""
def __init__(self, model_name: str = "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7"):
"""
Initialise le classifieur zero-shot.
:param model_name: Nom du modèle Hugging Face à utiliser.
"""
self.model_name = model_name
self._pipeline = None
@property
def classifier_pipeline(self):
"""
Initialisation tardive (lazy loading) du pipeline pour économiser de la mémoire et du temps au démarrage.
"""
if self._pipeline is None:
logger.info(f"Chargement du pipeline de classification avec le modèle {self.model_name} (ceci peut prendre quelques secondes)...")
# On laisse Hugging Face gérer le choix du device (GPU s'il est dispo, sinon CPU)
self._pipeline = pipeline("zero-shot-classification", model=self.model_name)
return self._pipeline
def classify(self, text: str, candidate_labels: List[str] = None) -> Dict[str, Any]:
"""
Classifie un texte selon une liste de catégories candidates.
Si aucune catégorie n'est fournie, utilise les catégories de harcèlement par défaut.
:param text: Le texte à classifier.
:param candidate_labels: Liste des catégories (labels).
:return: Dictionnaire contenant les labels et leurs scores associés.
"""
if candidate_labels is None:
candidate_labels = ["Cyberharcèlement", "Insulte", "Menace", "Non-harcèlement"]
if not text or not text.strip():
# Si le texte est vide, on renvoie une structure vide ou par défaut
return {"labels": [], "scores": []}
try:
# On exécute le pipeline de classification
result = self.classifier_pipeline(text, candidate_labels=candidate_labels)
return result
except Exception as e:
logger.error(f"Erreur lors de la classification du texte : {e}")
raise RuntimeError(f"Échec de la classification : {e}") from e