@pszemraj
Last active September 1, 2024 21:00
Inference with NVIDIA's domain classifier (nvidia/domain-classifier) over a Hugging Face dataset.
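Usage sketch: main() is exposed through Fire, so the script can be run from the command line (the filename classify_domains.py is arbitrary, use whatever you save this gist as):

python classify_domains.py --dataset_name <hf-dataset-id> --text_column text --batch_size 32

It writes the classified dataset to ./domain_classified_dataset.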
import logging
import os

import fire
import torch
from datasets import load_dataset
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


def check_ampere_gpu():
    """
    Check if the GPU supports NVIDIA Ampere or later and enable TF32 in PyTorch if it does.
    """
    # Check if CUDA is available
    if not torch.cuda.is_available():
        print("No GPU detected, running on CPU.")
        return

    try:
        # Get the compute capability of the GPU
        device = torch.cuda.current_device()
        capability = torch.cuda.get_device_capability(device)
        major, minor = capability

        # Check if the GPU is Ampere or newer (compute capability >= 8.0)
        if major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            gpu_name = torch.cuda.get_device_name(device)
            print(
                f"{gpu_name} (compute capability {major}.{minor}) supports NVIDIA Ampere or later, "
                "enabled TF32 in PyTorch."
            )
        else:
            gpu_name = torch.cuda.get_device_name(device)
            print(
                f"{gpu_name} (compute capability {major}.{minor}) does not support NVIDIA Ampere or later."
            )
    except Exception as e:
        print(f"Error occurred while checking GPU: {e}")


class DomainModel(nn.Module, PyTorchModelHubMixin):
    """Pretrained encoder with a linear classification head, loadable from the Hub via PyTorchModelHubMixin."""

    def __init__(self, config):
        super(DomainModel, self).__init__()
        self.model = AutoModel.from_pretrained(config["base_model"])
        self.dropout = nn.Dropout(config["fc_dropout"])
        self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))

    def forward(self, input_ids, attention_mask):
        features = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        dropped = self.dropout(features)
        outputs = self.fc(dropped)
        # Class probabilities from the first ([CLS]) token representation
        return torch.softmax(outputs[:, 0, :], dim=1)


def get_workers():
    """Use half of the available CPU cores (at least one) for dataset processing."""
    return max(1, os.cpu_count() // 2)


def get_device_type(model):
    """Return the device type string ('cuda' or 'cpu') for a model."""
    device = str(model.device)
    return device.split(":")[0]


def load_model(model_name="nvidia/domain-classifier", device=None):
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = DomainModel.from_pretrained(model_name).to(device)
    model.eval()
    return config, tokenizer, model, device


def classify_batch(batch, tokenizer, model, config, device, text_column):
    inputs = tokenizer(
        batch[text_column], return_tensors="pt", padding="longest", truncation=True
    ).to(device)
    with torch.no_grad(), torch.autocast(get_device_type(model.model)):
        outputs = model(inputs["input_ids"], inputs["attention_mask"])

    predicted_classes = torch.argmax(outputs, dim=1)
    predicted_labels = [
        config.id2label[class_idx.item()] for class_idx in predicted_classes
    ]
    batch["domain_prediction"] = predicted_labels
    return batch


def main(
    dataset_name: str,
    text_column: str = "text",
    model_name: str = "nvidia/domain-classifier",
    batch_size: int = 32,
):
    logger.info(f"Loading dataset: {dataset_name}")
    dataset = load_dataset(dataset_name, num_proc=get_workers())
    logger.info(f"Dataset loaded: {dataset}")

    check_ampere_gpu()
    logger.info(f"Loading model: {model_name}")
    config, tokenizer, model, device = load_model(model_name)

    logger.info("Starting inference")
    classified_dataset = dataset.map(
        lambda batch: classify_batch(
            batch, tokenizer, model, config, device, text_column
        ),
        batched=True,
        batch_size=batch_size,
        desc="Classifying texts",
    )
    logger.info("Inference complete")

    logger.info("Saving updated dataset")
    classified_dataset.save_to_disk("domain_classified_dataset")
    logger.info("Processing complete!")
    return classified_dataset


if __name__ == "__main__":
    fire.Fire(main)
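To read the results back later, something like the following should work (a minimal sketch; the "train" split name is an assumption and depends on the dataset that was classified):

from datasets import load_from_disk

classified = load_from_disk("domain_classified_dataset")
print(classified)  # original splits, each with an added "domain_prediction" column
print(classified["train"]["domain_prediction"][:5])  # first few predicted domain labels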