Spaces:
Paused
Paused
| import spaces | |
| import gradio as gr | |
| import torch | |
| import time | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
# ============================================================
# CONFIG
# ============================================================
# Hugging Face Hub ids: the base checkpoint and the LoRA adapter applied to it.
MODEL_NAME = "rikunarita/Qwen3-4B-Thinking-2507-Genius-Coder"
LORA_MODEL_NAME = "rahul7star/Qwen3-4B-Thinking-2509-Genius-Coder-AI"

# Token budgets: prompt truncation length and the hard cap on generated tokens.
MAX_INPUT_TOKENS = MAX_NEW_TOKENS = 4096

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# ============================================================
# LOAD TOKENIZER
# ============================================================
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# ============================================================
# LOAD BASE MODEL
# ============================================================
print("Loading base model...")
# Both dtype and device placement are left to the library ("auto").
_BASE_MODEL_LOAD_KWARGS = {
    "torch_dtype": "auto",
    "device_map": "auto",
}
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, **_BASE_MODEL_LOAD_KWARGS
)
# This script only runs inference, so put the module in eval mode right away.
base_model.eval()
print("Base model loaded successfully")
# ============================================================
# OPTIONAL LORA LOAD
# ============================================================
def _attach_lora(base, adapter_id):
    """Try to wrap ``base`` with the LoRA adapter ``adapter_id``.

    Loading is best-effort: any exception is printed and the plain base
    model is kept so the app can still start.

    Returns:
        A ``(model, loaded)`` pair. ``model`` is the adapter-wrapped model
        when loading succeeded and ``base`` otherwise; ``loaded`` tells the
        two cases apart.
    """
    try:
        print("Attempting to load LoRA adapter...")
        wrapped = PeftModel.from_pretrained(
            base,
            adapter_id,
            torch_dtype="auto",
        )
        wrapped.eval()
        print("β LoRA loaded successfully")
        return wrapped, True
    except Exception as e:
        print("β οΈ LoRA not loaded:", e)
        return base, False


# `model` is what generation uses when LoRA is enabled; `lora_loaded` records
# whether the adapter is actually available.
model, lora_loaded = _attach_lora(base_model, LORA_MODEL_NAME)
# ============================================================
# SYSTEM PROMPT
# ============================================================
# One entry per prompt line; joined below with a trailing newline so the
# resulting text matches a triple-quoted block ending on its own line.
_SYSTEM_PROMPT_LINES = (
    "You are a professional AI Coding Assistant.",
    "Your responses must be:",
    "- Clear and concise",
    "- Well-structured with headings and bullet points",
    "- Technically accurate",
    "- Written in a formal, professional tone",
    "- Focused on best practices and production-quality code",
)
# Sent as the "system" message of every chat request.
SYSTEM_PROMPT = "\n".join(_SYSTEM_PROMPT_LINES) + "\n"
| # ============================================================ | |
| # GENERATION FUNCTION | |
| # ============================================================ | |
def generate_answer(question, max_tokens, use_lora):
    """Generate a reply to ``question`` with the chat model.

    Args:
        question: Raw user text. ``None``, empty, or whitespace-only input is
            rejected with a fixed message and never reaches the model.
        max_tokens: Requested generation budget; coerced with ``int()`` and
            capped at ``MAX_NEW_TOKENS``.
        use_lora: When true and the adapter was loaded at startup, generate
            with the LoRA weights active; otherwise generate with the
            adapter disabled.

    Returns:
        The decoded completion with the prompt tokens stripped, the string
        ``"No output generated."`` if the completion is empty, or
        ``"Error occurred: ..."`` if anything raised during generation.
    """
    print("\n================ GENERATE ANSWER START ================")
    if not question or not question.strip():
        return "Please enter a valid question."
    try:
        start_time = time.time()

        # PeftModel.from_pretrained() injects the adapter layers into
        # base_model's own modules, so calling base_model.generate() would
        # still run with LoRA applied. To get real base-model output, always
        # go through `model` and switch the adapter off for this call.
        # When the adapter never loaded, `model` is the plain base model.
        bypass_adapter = lora_loaded and not use_lora

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": question.strip()},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # NOTE(review): truncation is applied to the fully templated prompt.
        # If the tokenizer truncates on the right (presumably the default --
        # confirm `tokenizer.truncation_side`), an over-long question loses
        # the trailing generation prompt as well.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_INPUT_TOKENS,
        ).to(DEVICE)
        input_token_count = inputs.input_ids.shape[-1]
        print(f"Input tokens: {input_token_count}")

        max_tokens = min(int(max_tokens), MAX_NEW_TOKENS)
        print(f"Final max_new_tokens: {max_tokens}")

        generation_options = {
            "max_new_tokens": max_tokens,
            "do_sample": False,
            "repetition_penalty": 1.05,
            "use_cache": True,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        print("π Starting generation...")
        with torch.no_grad():
            if bypass_adapter:
                with model.disable_adapter():
                    output = model.generate(**inputs, **generation_options)
            else:
                output = model.generate(**inputs, **generation_options)
        print("β Generation finished")

        # generate() returns prompt + completion; keep only the new tokens.
        generated_tokens = output[0][input_token_count:]
        response = tokenizer.decode(
            generated_tokens,
            skip_special_tokens=True,
        )
        print(response)
        print(f"Generated tokens: {generated_tokens.shape[-1]}")
        print(f"β± Total time: {time.time() - start_time:.2f} sec")
        print("================ GENERATE ANSWER END ==================\n")
        return response.strip() or "No output generated."
    except Exception as e:
        # UI boundary: log the traceback and return the error as text.
        import traceback

        traceback.print_exc()
        if torch.cuda.is_available():
            # Release cached GPU blocks left over from the failed attempt.
            torch.cuda.empty_cache()
        return f"Error occurred: {str(e)}"
# ============================================================
# UI
# ============================================================
# Static page assets, kept apart from the layout code below.
_HEADER_MARKDOWN = """
# π€ Professional Coding Assistant
**Qwen3-4B + Optional LoRA**
- β‘ Stable GPU inference
- π§ Deterministic responses
- π» Production-quality code
"""

# Browser-side handler: copies the rendered answer text to the clipboard.
_COPY_ANSWER_JS = """
() => {
const el = document.querySelector('#answer_box');
navigator.clipboard.writeText(el.innerText);
}
"""

_PAGE_CSS = """
.gradio-container { max-width: 900px !important; margin: auto; }
textarea { font-size: 14px !important; }
"""

with gr.Blocks() as demo:
    gr.Markdown(_HEADER_MARKDOWN)

    # Inputs and output, declared in the order they appear on the page.
    question = gr.Textbox(
        label="Your Question",
        placeholder="Explain Quick Sort with complexity and a Python example",
        value="write a python code using pytorch for a simple neural network demo",
        lines=4,
    )
    answer = gr.Markdown(label="AI Response", elem_id="answer_box")
    max_tokens = gr.Slider(64, 4096, value=1024, step=32, label="Max New Tokens")
    # Pre-checked only when the adapter actually loaded at startup.
    use_lora = gr.Checkbox(value=lora_loaded, label="Enable LoRA Adapter")

    with gr.Row():
        submit = gr.Button("Generate Answer", variant="primary")
        copy_btn = gr.Button("π Copy Response")
        clear = gr.Button("Clear")

    # Event wiring.
    submit.click(
        fn=generate_answer,
        inputs=[question, max_tokens, use_lora],
        outputs=answer,
    )
    # Reset both the question box and the answer panel to empty strings.
    clear.click(fn=lambda: ("", ""), outputs=[question, answer])
    # No Python callback: the copy runs entirely in the browser.
    copy_btn.click(fn=None, js=_COPY_ANSWER_JS)

# NOTE(review): theme/css are passed to launch() here; this needs a Gradio
# release whose launch() accepts them -- confirm against the pinned version.
demo.launch(theme=gr.themes.Soft(), css=_PAGE_CSS)