Instructions to use dejanseo/LinkBERT with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use dejanseo/LinkBERT with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("token-classification", model="dejanseo/LinkBERT")
```

```python
# Load model directly
from transformers import AutoTokenizer, AutoModelForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("dejanseo/LinkBERT")
model = AutoModelForTokenClassification.from_pretrained("dejanseo/LinkBERT")
```
- Notebooks
- Google Colab
- Kaggle
| from transformers import BertForTokenClassification, BertTokenizer, AutoConfig | |
| import torch | |
| from typing import Dict, List, Any | |
class EndpointHandler:
    """Inference handler that marks LinkBERT's predicted link tokens in text.

    The input text is WordPiece-tokenized once and fed to the
    token-classification model in chunks of at most ``CHUNK_SIZE`` tokens.
    The original tokens are then stitched back into text, and every
    contiguous run of tokens predicted as ``LINK_LABEL`` is wrapped in
    ``<u>...</u>``.
    """

    # Tokens per model call: 512 positions minus the [CLS] and [SEP] specials
    # (assumes a standard 512-position BERT; TODO confirm against config).
    CHUNK_SIZE = 510
    # Predicted class index treated as "link" (the original code used 1).
    # NOTE(review): confirm against self.config.id2label.
    LINK_LABEL = 1

    def __init__(
        self,
        path: str = "dejanseo/LinkBERT",
        tokenizer_name: str = "bert-large-cased",
    ):
        """Load the model configuration, weights and tokenizer.

        Args:
            path: Model repo id or local directory holding the fine-tuned
                token-classification weights and config.
            tokenizer_name: Tokenizer to load. Defaults to the tokenizer the
                handler has always used.
        """
        # Load the configuration from the saved model.
        self.config = AutoConfig.from_pretrained(path)
        self.model = BertForTokenClassification.from_pretrained(path, config=self.config)
        self.model.eval()  # evaluation mode: disables dropout
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)

    @staticmethod
    def _chunk_tokens(tokens: List[str], size: int) -> List[List[str]]:
        """Split ``tokens`` into consecutive slices of at most ``size`` items."""
        return [tokens[i:i + size] for i in range(0, len(tokens), size)]

    def split_into_chunks(self, text: str, max_length: int = 510) -> List[str]:
        """Split ``text`` into strings of at most ``max_length`` WordPiece tokens.

        The strings are rebuilt with ``convert_tokens_to_string``. Re-tokenizing
        them may therefore not reproduce the same tokens, e.g. when a chunk
        starts mid-word with a ``##`` piece. ``__call__`` works on token chunks
        directly for that reason.
        """
        tokens = self.tokenizer.tokenize(text)
        return [
            self.tokenizer.convert_tokens_to_string(chunk)
            for chunk in self._chunk_tokens(tokens, max_length)
        ]

    def _predict_chunk(self, tokens: List[str]) -> List[int]:
        """Return the argmax class id for each token in one chunk.

        The chunk is converted to ids directly, rather than re-tokenized from
        a string, so its length is exactly ``len(tokens)`` plus the two
        special tokens.
        """
        ids = self.tokenizer.convert_tokens_to_ids(tokens)
        ids = self.tokenizer.build_inputs_with_special_tokens(ids)
        input_ids = torch.tensor([ids])
        with torch.no_grad():
            logits = self.model(input_ids).logits
        # Drop the predictions for the [CLS] and [SEP] positions.
        return torch.argmax(logits, dim=-1)[0, 1:-1].tolist()

    @staticmethod
    def _reconstruct(tokens: List[str], predictions: List[int], link_label: int = 1) -> str:
        """Rebuild text from WordPiece tokens, underlining link predictions.

        Args:
            tokens: WordPiece tokens; continuation pieces start with ``##``.
            predictions: Class id per token, aligned with ``tokens``.
            link_label: Class id that marks a token as part of a link.

        Returns:
            The text with words separated by single spaces, ``##`` pieces glued
            to the preceding piece, and each contiguous run of link tokens
            wrapped once in ``<u>...</u>``.
        """
        parts: List[str] = []
        in_link = False
        for token, pred in zip(tokens, predictions):
            is_subword = token.startswith("##")
            piece = token[2:] if is_subword else token
            is_link = pred == link_label
            # Close a finished link run before any separating space.
            if in_link and not is_link:
                parts.append("</u>")
                in_link = False
            if parts and not is_subword:
                parts.append(" ")
            # Open a new link run after the separating space.
            if is_link and not in_link:
                parts.append("<u>")
                in_link = True
            parts.append(piece)
        if in_link:
            parts.append("</u>")
        return "".join(parts)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run link prediction on ``data["inputs"]``.

        Args:
            data: Request payload; ``data["inputs"]`` is the text to process.
                A missing or ``None`` value is treated as empty text.

        Returns:
            A one-element list ``[{"text": marked_text}]``. Tokens predicted as
            links appear inside ``<u>...</u>``.
        """
        text = data.get("inputs") or ""
        # Tokenize once and keep these tokens for reconstruction, so chunk
        # boundaries never alter the text or split words with extra spaces.
        tokens = self.tokenizer.tokenize(text)
        predictions: List[int] = []
        for chunk in self._chunk_tokens(tokens, self.CHUNK_SIZE):
            predictions.extend(self._predict_chunk(chunk))
        return [{"text": self._reconstruct(tokens, predictions, self.LINK_LABEL)}]