first commit
This commit is contained in:
@@ -0,0 +1,51 @@
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from transformers import pipeline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ZeroShotClassifier:
|
||||
"""
|
||||
Classe responsable de la classification de texte à l'aide de modèles Hugging Face Zero-Shot.
|
||||
"""
|
||||
def __init__(self, model_name: str = "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7"):
|
||||
"""
|
||||
Initialise le classifieur zero-shot.
|
||||
:param model_name: Nom du modèle Hugging Face à utiliser.
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self._pipeline = None
|
||||
|
||||
@property
|
||||
def classifier_pipeline(self):
|
||||
"""
|
||||
Initialisation tardive (lazy loading) du pipeline pour économiser de la mémoire et du temps au démarrage.
|
||||
"""
|
||||
if self._pipeline is None:
|
||||
logger.info(f"Chargement du pipeline de classification avec le modèle {self.model_name} (ceci peut prendre quelques secondes)...")
|
||||
# On laisse Hugging Face gérer le choix du device (GPU s'il est dispo, sinon CPU)
|
||||
self._pipeline = pipeline("zero-shot-classification", model=self.model_name)
|
||||
return self._pipeline
|
||||
|
||||
def classify(self, text: str, candidate_labels: List[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Classifie un texte selon une liste de catégories candidates.
|
||||
Si aucune catégorie n'est fournie, utilise les catégories de harcèlement par défaut.
|
||||
:param text: Le texte à classifier.
|
||||
:param candidate_labels: Liste des catégories (labels).
|
||||
:return: Dictionnaire contenant les labels et leurs scores associés.
|
||||
"""
|
||||
if candidate_labels is None:
|
||||
candidate_labels = ["Cyberharcèlement", "Insulte", "Menace", "Non-harcèlement"]
|
||||
|
||||
if not text or not text.strip():
|
||||
# Si le texte est vide, on renvoie une structure vide ou par défaut
|
||||
return {"labels": [], "scores": []}
|
||||
|
||||
try:
|
||||
# On exécute le pipeline de classification
|
||||
result = self.classifier_pipeline(text, candidate_labels=candidate_labels)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de la classification du texte : {e}")
|
||||
raise RuntimeError(f"Échec de la classification : {e}") from e
|
||||
Reference in New Issue
Block a user