52 lines
2.3 KiB
Python
52 lines
2.3 KiB
Python
import logging
|
|
from typing import List, Dict, Any, Optional
|
|
from transformers import pipeline
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class ZeroShotClassifier:
|
|
"""
|
|
Classe responsable de la classification de texte à l'aide de modèles Hugging Face Zero-Shot.
|
|
"""
|
|
def __init__(self, model_name: str = "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7"):
|
|
"""
|
|
Initialise le classifieur zero-shot.
|
|
:param model_name: Nom du modèle Hugging Face à utiliser.
|
|
"""
|
|
self.model_name = model_name
|
|
self._pipeline = None
|
|
|
|
@property
|
|
def classifier_pipeline(self):
|
|
"""
|
|
Initialisation tardive (lazy loading) du pipeline pour économiser de la mémoire et du temps au démarrage.
|
|
"""
|
|
if self._pipeline is None:
|
|
logger.info(f"Chargement du pipeline de classification avec le modèle {self.model_name} (ceci peut prendre quelques secondes)...")
|
|
# On laisse Hugging Face gérer le choix du device (GPU s'il est dispo, sinon CPU)
|
|
self._pipeline = pipeline("zero-shot-classification", model=self.model_name)
|
|
return self._pipeline
|
|
|
|
def classify(self, text: str, candidate_labels: List[str] = None) -> Dict[str, Any]:
|
|
"""
|
|
Classifie un texte selon une liste de catégories candidates.
|
|
Si aucune catégorie n'est fournie, utilise les catégories de harcèlement par défaut.
|
|
:param text: Le texte à classifier.
|
|
:param candidate_labels: Liste des catégories (labels).
|
|
:return: Dictionnaire contenant les labels et leurs scores associés.
|
|
"""
|
|
if candidate_labels is None:
|
|
candidate_labels = ["Cyberharcèlement", "Insulte", "Menace", "Non-harcèlement"]
|
|
|
|
if not text or not text.strip():
|
|
# Si le texte est vide, on renvoie une structure vide ou par défaut
|
|
return {"labels": [], "scores": []}
|
|
|
|
try:
|
|
# On exécute le pipeline de classification
|
|
result = self.classifier_pipeline(text, candidate_labels=candidate_labels)
|
|
return result
|
|
except Exception as e:
|
|
logger.error(f"Erreur lors de la classification du texte : {e}")
|
|
raise RuntimeError(f"Échec de la classification : {e}") from e
|