Generalize multimodality (llava/minigpt4 7b and 13b now supported) (#1741)

This commit is contained in:
Wojtab 2023-05-10 01:18:02 +02:00 committed by GitHub
parent a2b25322f0
commit e9e75a9ec7
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 812 additions and 371 deletions

1
.gitignore vendored
View file

@ -4,6 +4,7 @@ training/datasets
extensions/silero_tts/outputs
extensions/elevenlabs_tts/outputs
extensions/sd_api_pictures/outputs
extensions/multimodal/pipelines
logs
loras
models

View file

@ -31,6 +31,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* [llama.cpp](docs/llama.cpp-models.md)
* [RWKV model](docs/RWKV-model.md)
* [LoRA (loading and training)](docs/Using-LoRAs.md)
* [Multimodal pipelines, including LLaVA and MiniGPT-4](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal)
* Softprompts
* [Extensions](docs/Extensions.md) - see the [user extensions list](https://github.com/oobabooga/text-generation-webui-extensions)
@ -281,6 +282,12 @@ Optionally, you can use the following command-line flags:
| `--api` | Enable the API extension. |
| `--public-api` | Create a public URL for the API using Cloudfare. |
#### Multimodal
| Flag | Description |
|---------------------------------------|-------------|
| `--multimodal-pipeline PIPELINE` | The multimodal pipeline to use. Examples: `llava-7b`, `llava-13b`. |
Out of memory errors? [Check the low VRAM guide](docs/Low-VRAM-guide.md).
## Presets

View file

@ -1,4 +1,4 @@
user: "### Human"
bot: "### Assistant"
turn_template: "<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n"
context: "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.\n### Human: \nHi!\n### Assistant: \nHi there! How can I help you today?\n"
user: "### Human:"
bot: "### Assistant:"
turn_template: "<|user|> <|user-message|><|bot|> <|bot-message|>\n"
context: "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.### Human: Hi!### Assistant: Hi there! How can I help you today?\n"

View file

@ -33,7 +33,7 @@ Most of these have been created by the extremely talented contributors that you
|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. |
|[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. |
|[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). |
|[llava](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava) | Adds LLaVA multimodal model support. For detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava/README.md) in the extension directory. |
|[multimodal](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal) | Adds multimodality support (text+images). For detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal/README.md) in the extension directory. |
|[openai](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai)| Creates an API that mimics the OpenAI API and can be used as a drop-in replacement. |
|[superbooga](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superbooga)| An extension that uses ChromaDB to create an arbitrarily large pseudocontext, taking as input text files, URLs, or pasted text. Based on https://github.com/kaiokendev/superbig. |
@ -50,7 +50,8 @@ Most of these have been created by the extremely talented contributors that you
| `def bot_prefix_modifier(string)` | Applied in chat mode to the prefix for the bot's reply (more on that below). |
| `def custom_generate_reply(...)` | Overrides the main text generation function. |
| `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. |
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See `llava` extension for an example |
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See `multimodal` extension for an example |
| `def custom_tokenized_length(prompt)` | Used in conjunction with `tokenizer_modifier`, returns the length in tokens of `prompt`. See `multimodal` extension for an example |
Additionally, the script may define two special global variables:
@ -78,7 +79,7 @@ input_hijack = {
```
This is only relevant in chat mode. If your extension sets `input_hijack['state']` to `True` at any moment, the next call to `modules.chat.chatbot_wrapper` will use the values inside `input_hijack['value']` as the user input for text generation. See the `send_pictures` extension above for an example.
Additionally, your extension can set the value to be a callback, in the form of `def cb(text: str, visible_text: str) -> [str, str]`. See the `llava` extension above for an example.
Additionally, your extension can set the value to be a callback, in the form of `def cb(text: str, visible_text: str) -> [str, str]`. See the `multimodal` extension above for an example.
## The `bot_prefix_modifier`
@ -100,13 +101,22 @@ Marie Antoinette will become very enthusiastic in all her messages.
In order to use your extension, you must start the web UI with the `--extensions` flag followed by the name of your extension (the folder under `text-generation-webui/extension` where `script.py` resides).
You can activate more than one extension at a time by providing their names separated by spaces. The input, output and bot prefix modifiers will be applied in the specified order. For `custom_generate_chat_prompt`, only the first declaration encountered will be used and the rest will be ignored.
You can activate more than one extension at a time by providing their names separated by spaces. The input, output and bot prefix modifiers will be applied in the specified order.
```
python server.py --extensions enthusiasm translate # First apply enthusiasm, then translate
python server.py --extensions translate enthusiasm # First apply translate, then enthusiasm
```
Do note, that for:
- `custom_generate_chat_prompt`
- `custom_generate_reply`
- `tokenizer_modifier`
- `custom_tokenized_length`
only the first declaration encountered will be used and the rest will be ignored.
## `custom_generate_reply` example
Once defined in a `script.py`, this function is executed in place of the main generation functions. You can use it to connect the web UI to an external API, or to load a custom model that is not supported yet.
@ -167,7 +177,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
# Building the prompt
i = len(shared.history['internal']) - 1
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
while i >= 0 and get_encoded_length(''.join(rows)) < max_length:
if _continue and i == len(shared.history['internal']) - 1:
rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip())
else:
@ -190,7 +200,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
# Adding the Character prefix
rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
while len(rows) > min_rows and get_encoded_length(''.join(rows)) >= max_length:
rows.pop(1)
prompt = ''.join(rows)

View file

@ -3,7 +3,7 @@ import traceback
from threading import Thread
from typing import Callable, Optional
from modules.text_generation import encode
from modules.text_generation import get_encoded_length
def build_parameters(body):
@ -11,7 +11,7 @@ def build_parameters(body):
prompt_lines = [k.strip() for k in prompt.split('\n')]
max_context = body.get('max_context_length', 2048)
while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context:
while len(prompt_lines) >= 0 and get_encoded_length('\n'.join(prompt_lines)) > max_context:
prompt_lines.pop(0)
prompt = '\n'.join(prompt_lines)

View file

@ -1,71 +0,0 @@
# LLaVA
## Description
Adds [LLaVA 13B](https://github.com/haotian-liu/LLaVA) multimodality support to text-generation-webui.
https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b742-3204bb13b8fc.mp4
## LLaVA-7B
7B version currently isn't supported. It will be supported if/when [more generic multimodality support](https://github.com/oobabooga/text-generation-webui/discussions/1687) gets implemented.
## Usage
To run this extension, download LLaVA weights, for example from [here](https://huggingface.co/wojtab/llava-13b-v0-4bit-128g) (note: it's a 4-bit [GPTQ quantization](https://github.com/oobabooga/text-generation-webui/tree/main/docs/GPTQ-models-(4-bit-mode).md), done on "old CUDA" branch), and then start server.py with `--extensions llava` argument.
Do note, that each image takes up 258 tokens, so adjust max_new_tokens to be at most 1700 (recommended value is between 200 to 500), so the images don't get truncated.
To send an image, just upload it to the extension field below chat, and send a prompt as always. The image will be added to the end of your message. If you wish to modify the placement, include a string `<image>` in your prompt.
Additionally, there is *Embed all images, not only the last one* checkbox. It modifies the image embeddings, by default (if it's unchecked), all but the most recent images have their embeddings empty, so they are not fed to the network. From initial testing, it seems as LLaVA considers the features in all images at the same time, so by default the extension skips previous images. If you want to include them anyway, just tick this checkbox.
## Extension config
This extension uses following parameters (from settings.json):
|Parameter|Description|
|---------|-----------|
|`llava-clip_bits`|Number of bits to load CLIP feature extractor in (either 32 or 16, default=32)|
|`llava-clip_device`|Torch device to run the extractor on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|`llava-clip_repo`|Huggingface repository of CLIP model, `openai/clip-vit-large-patch14` by default. There should be no need to change it|
|`llava-projector_bits`|Number of bits to load CLIP->LLaMA feature projector in (either 32 or 16, default=32)|
|`llava-projector_device`|Torch device to run the CLIP->LLaMA feature projector on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|`llava-projector_repo`|Huggingface repository of multimodal projector, `liuhaotian/LLaVA-13b-delta-v0` by default. There should be no need to change it|
|`llava-projector_filename`|The filename of multimodal projector weights, `mm_projector.bin` by default. There should be no need to change it|
|`llava-add_all_images_to_prompt`|Default value of "Embed all images, not only the last one" checkbox|
## Technical description
### Original LLaVA
The default LLaVA implementation uses modified `transformers` library, however this extension forgoes this requirement. The transformers are modified in LLaVA in such a way, that the entire LLaVA model gets loaded, and the inference now looks as follows:
```
images --> CLIP --> projector --> input embeddings for images --> |
| --> LLaMA
prompt -------------------------> input embeddings for text ----> |
```
The images are represented in the prompt by the following token IDs:
- 32000 - `<im_patch>` - placeholder token for embeddings from projector
- 32001 - `<im_start>` - token marking start of an image
- 32002 - `<im_end>` - token marking end of an image
By default, image will be represented as `<im_start><im_patch>*256<im_end>`. The input embeddings for an image are converted with a single linear layer of the projector, then they are placed instead of `<im_patch>` tokens.
The concatenated prompt then gets fed to fine-tuned LLaMA.
### In this extension
Using default transformers, they only load the LLaMA part of LLaVA, ignoring the added projector weights, and not loading CLIP. We then reconstruct the `images -> CLIP -> projector` pipeline ourselves, then concatenate the input embeddings, and feed it to LLaMA loaded by transformers. This allows us to use normal flow from webui to load this model, and just hijack the model input with additional features.
Splitting it to 3 separate models, allows us to configure each of them, and to move them to different devices(for example we can run CLIP+projector on CPU and LLaMA on GPU). Also, it enables us to use 4-bit GPTQ quantization for LLaVA, massively cutting down the VRAM requirement (it should be possible to fit on 12GB of VRAM with full context size by moving CLIP and projector to CPU).
### Usage through API
You can run the multimodal inference through API, by inputting the images to prompt. Images are embedded like so: `f'<img src="data:image/jpeg;base64,{img_str}">'`, where `img_str` is base-64 jpeg data. Python example:
```Python
import base64
import requests
CONTEXT = "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.\n### Human: \nHi!\n### Assistant: \nHi there! How can I help you today?\n"
with open('extreme_ironing.jpg', 'rb') as f:
img_str = base64.b64encode(f.read()).decode('utf-8')
prompt = CONTEXT + f'### Human: \nWhat is unusual about this image: \n<img src="data:image/jpeg;base64,{img_str}">\n### Assistant: \n'
print(requests.post('http://127.0.0.1:5000/api/v1/generate', json={'prompt': prompt, 'stopping_strings': ['\n###']}).json())
```
script output:
```Python
{'results': [{'text': "The unusual aspect of this image is that a man is standing on top of a yellow minivan while doing his laundry. He has set up a makeshift clothes line using the car's rooftop as an outdoor drying area. This scene is uncommon because people typically do their laundry indoors, in a dedicated space like a laundromat or a room in their home, rather than on top of a moving vehicle. Additionally, hanging clothes on the car could be potentially hazardous or illegal in some jurisdictions due to the risk of damaging the vehicle or causing accidents on the road.\n##"}]}
```

View file

@ -1,272 +1,6 @@
import base64
import re
import time
from dataclasses import dataclass
from functools import partial
from io import BytesIO
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel
from modules import shared
from modules.extensions import apply_extensions
from modules.text_generation import encode, get_max_prompt_length
params = {
"add_all_images_to_prompt": False,
# device to run CLIP on
"clip_device": None,
# bits to load clip in either 32 or 16 (it doesn't support 8-bit)
"clip_bits": 32,
# clip repository
"clip_repo": "openai/clip-vit-large-patch14",
# device to run projector on
"projector_device": None,
# projector bits, either 32 or 16
"projector_bits": 32,
# projector repository
"projector_repo": "liuhaotian/LLaVA-13b-delta-v0",
# file with the projector weights
"projector_file": "mm_projector.bin"
}
# If 'state' is True, will hijack the next chat generation
input_hijack = {
'state': False,
'value': ["", ""]
}
# initialized in ui, so that params are loaded from settings
llava_embedder = None
@dataclass
class Token:
token: str
id: int
class LLaVAEmbedder:
IM_PATCH = Token("<im_patch>", 32000)
IM_START = Token("<im_start>", 32001)
IM_END = Token("<im_end>", 32002)
def __init__(self):
self.clip_device = self._get_device("clip_device")
self.clip_dtype = self._get_dtype("clip_bits")
self.projector_device = self._get_device("projector_device")
self.projector_dtype = self._get_dtype("projector_bits")
self.image_processor, self.vision_tower, self.mm_projector = self._load_models()
def _get_device(self, setting_name):
if params[setting_name] is None:
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
return torch.device(params[setting_name])
def _get_dtype(self, setting_name):
return torch.float32 if int(params[setting_name]) == 32 else torch.float16
def _load_models(self):
start_ts = time.time()
print(f"LLaVA - Loading CLIP from {params['clip_repo']} as {self.clip_dtype} on {self.clip_device}...")
image_processor = CLIPImageProcessor.from_pretrained(params["clip_repo"], torch_dtype=self.clip_dtype)
vision_tower = CLIPVisionModel.from_pretrained(params["clip_repo"], torch_dtype=self.clip_dtype).to(self.clip_device)
print(f"LLaVA - Loading projector from {params['projector_repo']} as {self.projector_dtype} on {self.projector_device}...")
projector_path = hf_hub_download(params["projector_repo"], params["projector_file"])
mm_projector = torch.nn.Linear(1024, 5120)
projector_data = torch.load(projector_path)
mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False)
mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
mm_projector = mm_projector.to(self.projector_device)
print(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
return image_processor, vision_tower, mm_projector
def _update_prompt(self, prompt, images):
for _ in images:
# replace the image token with the image patch token in the prompt (each occurrence)
replace_token = LLaVAEmbedder.IM_PATCH.token * 256
replace_token = LLaVAEmbedder.IM_START.token + replace_token + LLaVAEmbedder.IM_END.token
prompt = re.sub(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', replace_token, prompt, 1)
return prompt
def _extract_image_features(self, images):
images = self.image_processor(images, return_tensors='pt')['pixel_values']
images = images.to(self.clip_device, dtype=self.clip_dtype)
with torch.no_grad():
image_forward_outs = self.vision_tower(images, output_hidden_states=True)
select_hidden_state_layer = -2
select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
image_features = select_hidden_state[:, 1:].to(self.projector_device, dtype=self.projector_dtype)
image_features = self.mm_projector(image_features)
return image_features
def forward(self, prompt, images, state):
prompt = self._update_prompt(prompt, images)
input_ids = encode(prompt, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))[0]
input_embeds = shared.model.model.embed_tokens(input_ids).to(self.projector_device)
if input_ids[0] == LLaVAEmbedder.IM_PATCH.id:
# prompt got truncated in the middle of an image, remove the image data
im_end = torch.where(input_ids == LLaVAEmbedder.IM_END.id)[0][0]
input_ids = input_ids[im_end+1:]
input_embeds = input_embeds[im_end+1:]
leftover_images = torch.where(input_ids == LLaVAEmbedder.IM_START.id)[0].shape[0]
print(f"LLaVA - WARNING: removed {len(images) - leftover_images} image(s) from prompt. The generation might be broken, try decreasing max_new_tokens")
images = images[-leftover_images:]
if len(images) == 0:
return prompt, input_ids, input_embeds, 0
total_embedded = 0
image_features = self._extract_image_features(images).to(self.projector_device)
image_start_tokens = torch.where(input_ids == LLaVAEmbedder.IM_START.id)[0]
if not torch.any(input_ids == LLaVAEmbedder.IM_PATCH.id) or len(image_start_tokens) == 0:
# multimodal LLM, but the current prompt is not multimodal/truncated
return prompt, input_ids, input_embeds, total_embedded
cur_image_idx = 0
if not params['add_all_images_to_prompt']:
image_start_tokens = [image_start_tokens[-1]]
cur_image_idx = -1
for image_start_token_pos in image_start_tokens:
cur_image_features = image_features[cur_image_idx]
num_patches = cur_image_features.shape[0]
input_embeds = torch.cat((input_embeds[:image_start_token_pos+1], cur_image_features, input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
cur_image_idx += 1
total_embedded += 1
return prompt, input_ids, input_embeds, total_embedded
@staticmethod
def len_in_tokens(text):
images = re.findall(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', text)
image_tokens = 0
for _ in images:
image_tokens += 258
return len(encode(re.sub(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', '', text))[0]) + image_tokens
def add_chat_picture(picture, text, visible_text):
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
max_hw, min_hw = max(picture.size), min(picture.size)
aspect_ratio = max_hw / min_hw
shortest_edge = int(max(300 / aspect_ratio, 224))
longest_edge = int(shortest_edge * aspect_ratio)
w = shortest_edge if picture.width < picture.height else longest_edge
h = shortest_edge if picture.width >= picture.height else longest_edge
picture = picture.resize((w,h))
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
image = f'<img src="data:image/jpeg;base64,{img_str}">'
if '<image>' in text:
text = text.replace('<image>', image)
else:
text = text + '\n' + image
if visible_text == '' or visible_text is None:
visible_text = text
elif '<image>' in visible_text:
visible_text = visible_text.replace('<image>', image)
else:
visible_text = visible_text + '\n' + image
return text, visible_text
def custom_generate_chat_prompt(user_input, state, **kwargs):
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
_continue = kwargs['_continue'] if '_continue' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
rows = [f"{state['context'].strip()}\n"]
min_rows = 3
# Finding the maximum prompt size
chat_prompt_size = state['chat_prompt_size']
if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(state), chat_prompt_size)
prefix1 = f"{state['name1']}: "
prefix2 = f"{state['name2']}: "
i = len(shared.history['internal']) - 1
while i >= 0 and LLaVAEmbedder.len_in_tokens(''.join(rows)) < max_length:
if _continue and i == len(shared.history['internal']) - 1:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
else:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}\n")
string = shared.history['internal'][i][0]
if string != '':
rows.insert(1, f"{prefix1}{string.strip()}\n")
i -= 1
if impersonate:
min_rows = 2
rows.append(f"{prefix1}")
elif not _continue:
# Adding the user message
if len(user_input) > 0:
rows.append(f"{prefix1}{user_input}\n")
# Adding the Character prefix
rows.append(apply_extensions("bot_prefix", f"{prefix2}"))
while len(rows) > min_rows and LLaVAEmbedder.len_in_tokens(''.join(rows)) >= max_length:
rows.pop(1)
prompt = ''.join(rows)
if also_return_rows:
return prompt, rows
else:
return prompt
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
global params
start_ts = time.time()
image_matches = re.finditer(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', prompt)
images = [Image.open(BytesIO(base64.b64decode(match.group(1)))) for match in image_matches]
if len(images) == 0:
return prompt, input_ids, input_embeds
prompt, input_ids, input_embeds, total_embedded = llava_embedder.forward(prompt, images, state)
print(f'LLaVA - Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
return (prompt,
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
import logging
def ui():
global llava_embedder
llava_embedder = LLaVAEmbedder()
with gr.Column():
picture_select = gr.Image(label='Send a picture', type='pil')
# I found that it doesn't deal super well with multiple images, and demo ui had a bug where it included only the last image anyway
single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
# Prepare the input hijack
picture_select.upload(
lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
[picture_select],
None
)
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None)
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
shared.gradio['Generate'].click(lambda: None, None, picture_select)
shared.gradio['textbox'].submit(lambda: None, None, picture_select)
gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")

View file

@ -0,0 +1,85 @@
# Technical description of multimodal extension
## Working principle
Multimodality extension does most of the stuff which is required for any image input:
- adds the UI
- saves the images as base64 JPEGs to history
- provides the hooks to the UI
- if there are images in the prompt, it:
- splits the prompt to text and image parts
- adds image start/end markers to text parts, then encodes and embeds the text parts
- calls the vision pipeline to embed the images
- stitches the embeddings together, and returns them to text generation
- loads the appropriate vision pipeline, selected either from model name, or by specifying --multimodal-pipeline parameter
Now, for the pipelines, they:
- load the required vision models
- return some consts, for example the number of tokens taken up by image
- and most importantly: return the embeddings for LLM, given a list of images
## Prompts/history
To save images in prompt/history, this extension is using a base64 JPEG, wrapped in a HTML tag, like so:
```
<img src="data:image/jpeg;base64,{img_str}">
```
where `{img_str}` is the actual image data. This format makes displaying them in the UI for free. Do note, that this format is required to be exactly the same, the regex used to find the images is: `<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">`.
## LLM input
To describe the input, let's see it on an example prompt:
```
text1<image1>text2<image2>text3
```
where `textN` is N-th text, `<imageN>` is N-th image, in HTML format specified above.
**The first step is to split the prompt into image/text parts**, so we get:
```
['text1', '<image1>', 'text2', '<image2>', 'text3']
```
this is done in `MultimodalEmbedder._split_prompt(...)` function, which returns a list of `PromptPart`s - dataclasses wrapping the separate parts.
This function also appends the image start/end markers to text, which are provided by `AbstractMultimodalPipeline.image_start()` / `AbstractMultimodalPipeline.image_end()` functions. If image start is `<Img>`, and end is `</Img>`, this function will return:
```
['text1<Img>', '<image1>', '</Img>text2<Img>', '<image2>', '</Img>text3']
```
**The returned prompt parts are then turned into token embeddings.**
First, they are modified to token IDs, for the text it is done using standard `modules.text_generation.encode()` function, and for the images the returned token IDs are changed to placeholders. The placeholder is a list of `N` times `placeholder token id`, where `N` is specified using `AbstractMultimodalPipeline.num_image_embeds()`, and placeholder token IDs using `AbstractMultimodalPipeline.placeholder_token_id()`.
Now, based on the token IDs, the prompt might get truncated, especially if `max_new_tokens` are unreasonably high. Unfortunately, it can't be done simply, just by trimming the prompt to be short enough. This way will lead to sometimes splitting the prompt in the middle of an image embedding, which usually breaks the generation. Therefore, in this case, the entire image needs to be removed from input. This is done inside `MultimodalEmbedder._encode_text(...)` function.
**After the tokenization, the tokens need to get embedded**, the text and images are once again treated separately.
The text parts are turned to embeddings, using `AbstractMultimodalPipeline.embed_tokens(...)` function. It uses standard embedding function from the model, but to support many LLMs, the actual function is returned by the pipeline (as it might be different for different LLMs), for LLaMA it is `shared.model.model.embed_tokens(...)`.
The image parts are turned to embeddings, using `AbstractMultimodalPipeline.embed_images(...)` function. This function is specific for a given pipeline, it takes the images as input, forwards them through vision model/projector, and returns the embeddings.
**Now, the returned embeddings are stitched together**, using `torch.cat()`, this is creating the final input to the LLM.
## Pipelines
All of the pipelines should subclass `AbstractMultimodalPipeline` class. The idea is to allow for new pipelines to be added in the same way as user extensions - git clone into `extensions/multimodal/pipelines`.
The pipelines are the description of the vision part, containing vision model/multimodal projector. All of the pipelines should have an unique `name()`, which is then selected by user, in `--multimodal-pipeline` CLI argument. For an example, see `pipelines/llava/llava.py`.
## Pipeline modules
Pipelines are organized into "pipeline modules" - subdirectories in `pipelines` directory. The pipeline modules should contain a file called `pipelines.py`, that should contain the following fields:
- `available_pipelines: List[str]` - list of pipelines provided by this module, shown as the list of available pipelines to the user
- `def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]`: - a function to get a concrete pipeline by `name`, if `name` doesn't match any, should return `None`. `params` is the user settings for multimodal extension
- `def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]`: - a function to get a pipeline from `model_name`, should be eager to return `None`, unless the determination can be done clearly (for example: minigpt-4 bases on vicuna - it should never return the pipeline, but llava can, as it has its own specific LLM finetune)
**NOTE**: A pipeline module should lazy-import the pipelines only when necessary, and it should keep its imports to minimum
## Pipeline params
The pipelines will get the extension `params` in the constructor. They should honor the following fields:
- `vision_device` - string, specifying `torch.device` to run the vision model (CLIP/ViT) on
- `vision_bits` - int, number of fp bits to load the vision model(s) in
- `projector_device` - string, specifying `torch.device` to run the projector models (Linear layers, QFormer, etc.) on
- `projector_bits` - int, number of fp bits to load the projector models in
As a helper, `AbstractMultimodalPipeline` has `_get_device(self, setting_name: str, params: dict)` and `_get_dtype(self, setting_name: str, params: dict)` helper functions, which parse string/int and return `torch.device` / `torch.dtype`.

View file

@ -0,0 +1,78 @@
# Multimodal
## Description
Adds support for multimodality (text+images) to text-generation-webui.
https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b742-3204bb13b8fc.mp4
## Usage
To run this extension, download a LLM that supports multimodality, and then start server.py with the appropriate `--multimodal-pipeline` argument. Examples:
```
python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b --chat
python3 server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b --chat
python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b --chat
python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b --chat
```
There is built-in support for LLaVA-v0-13B and LLaVA-v0-7b. To install `minigpt4`:
- clone https://github.com/Wojtab/minigpt-4-pipeline into `extensions/multimodal/pipelines`
- install the requirements.txt
The same procedure should be used to install other pipelines, which can then me used with `--multimodal-pipeline [pipeline name]`. For additional multimodal pipelines refer to compatibility section below.
Do note, that each image takes up a considerable amount of tokens, so adjust `max_new_tokens` to be at most 1700 (recommended value is between 200 to 500), so the images don't get truncated.
To send an image, just upload it to the extension field below chat, and send a prompt as always. The image will be added to the end of your message. If you wish to modify the placement, include a string `<image>` in your prompt.
Additionally, there is *Embed all images, not only the last one* checkbox. It modifies the image embeddings, by default (if it's unchecked), all but the most recent images have their embeddings empty, so they are not fed to the network. It seems as some multimodal networks consider the features in all images at the same time as if they were a single image. Due to this behavior, by default the extension skips previous images. However, it can lead to sub-par generation on other pipelines. If you want to include all images, just tick this checkbox.
## Compatibility
As of now, the following multimodal pipelines are supported:
|Pipeline|`--multimodal-pipeline`|Default LLM|LLM info(for the linked model)|Pipeline repository|
|-|-|-|-|-|
|[LLaVA 13B](https://github.com/haotian-liu/LLaVA)|`llava-13b`|[LLaVA 13B](https://huggingface.co/wojtab/llava-13b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in|
|[LLaVA 7B](https://github.com/haotian-liu/LLaVA)|`llava-7b`|[LLaVA 7B](https://huggingface.co/wojtab/llava-7b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in|
|[MiniGPT-4 7B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-7b`|[Vicuna v0 7B](https://huggingface.co/TheBloke/vicuna-7B-GPTQ-4bit-128g)|GPTQ 4-bit quant, new format|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|[MiniGPT-4 13B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-13b`|[Vicuna v0 13B](https://huggingface.co/anon8231489123/vicuna-13b-GPTQ-4bit-128g)|GPTQ 4-bit quant, old CUDA|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
Some pipelines could support different LLMs, but do note that while it might work, it isn't a supported configuration.
DO NOT report bugs if you are using a different LLM.
DO NOT report bugs with pipelines in this repository (unless they are built-in)
## Extension config
This extension uses following parameters (from settings.json):
|Parameter|Description|
|---------|-----------|
|`multimodal-vision_bits`|Number of bits to load vision models (CLIP/ViT) feature extractor in (most pipelines should support either 32 or 16, default=32)|
|`multimodal-vision_device`|Torch device to run the feature extractor on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|`multimodal-projector_bits`|Number of bits to load feature projector model(s) in (most pipelines should support either 32 or 16, default=32)|
|`multimodal-projector_device`|Torch device to run the feature projector model(s) on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|`multimodal-add_all_images_to_prompt`|Default value of "Embed all images, not only the last one" checkbox|
## Usage through API
You can run the multimodal inference through API, by inputting the images to prompt. Images are embedded like so: `f'<img src="data:image/jpeg;base64,{img_str}">'`, where `img_str` is base-64 jpeg data. Python example:
```Python
import base64
import requests
CONTEXT = "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.### Human: Hi!### Assistant: Hi there! How can I help you today?\n"
with open('extreme_ironing.jpg', 'rb') as f:
img_str = base64.b64encode(f.read()).decode('utf-8')
prompt = CONTEXT + f'### Human: What is unusual about this image: \n<img src="data:image/jpeg;base64,{img_str}">### Assistant: '
print(requests.post('http://127.0.0.1:5000/api/v1/generate', json={'prompt': prompt, 'stopping_strings': ['\n###']}).json())
```
script output:
```Python
{'results': [{'text': "The unusual aspect of this image is that a man is standing on top of a yellow minivan while doing his laundry. He has set up a makeshift clothes line using the car's rooftop as an outdoor drying area. This scene is uncommon because people typically do their laundry indoors, in a dedicated space like a laundromat or a room in their home, rather than on top of a moving vehicle. Additionally, hanging clothes on the car could be potentially hazardous or illegal in some jurisdictions due to the risk of damaging the vehicle or causing accidents on the road.\n##"}]}
```
## For pipeline developers/technical description
see [DOCS.md](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/multimodal/DOCS.md)

View file

@ -0,0 +1,62 @@
from abc import ABC, abstractmethod
from typing import List, Optional
import torch
from PIL import Image
class AbstractMultimodalPipeline(ABC):
@staticmethod
@abstractmethod
def name() -> str:
'name of the pipeline, should be same as in --multimodal-pipeline'
pass
@staticmethod
@abstractmethod
def image_start() -> Optional[str]:
'return image start string, string representation of image start token, or None if not applicable'
pass
@staticmethod
@abstractmethod
def image_end() -> Optional[str]:
'return image end string, string representation of image end token, or None if not applicable'
pass
@staticmethod
@abstractmethod
def placeholder_token_id() -> int:
'return placeholder token id'
pass
@staticmethod
@abstractmethod
def num_image_embeds() -> int:
'return the number of embeds used by a single image (for example: 256 for LLaVA)'
pass
@abstractmethod
def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
'forward the images through vision pipeline, and return their embeddings'
pass
@staticmethod
@abstractmethod
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
'embed tokens, the exact function varies by LLM, for LLaMA it is `shared.model.model.embed_tokens`'
pass
@staticmethod
@abstractmethod
def placeholder_embeddings() -> torch.Tensor:
'get placeholder embeddings if there are multiple images, and `add_all_images_to_prompt` is False'
pass
def _get_device(self, setting_name: str, params: dict):
if params[setting_name] is None:
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
return torch.device(params[setting_name])
def _get_dtype(self, setting_name: str, params: dict):
return torch.float32 if int(params[setting_name]) == 32 else torch.float16

View file

@ -0,0 +1,177 @@
import base64
import logging
import re
from dataclasses import dataclass
from io import BytesIO
from typing import Any, List, Optional
import torch
from extensions.multimodal.pipeline_loader import load_pipeline
from modules import shared
from modules.text_generation import encode, get_max_prompt_length
from PIL import Image
@dataclass
class PromptPart:
text: str
image: Optional[Image.Image] = None
is_image: bool = False
input_ids: Optional[torch.Tensor] = None
embedding: Optional[torch.Tensor] = None
class MultimodalEmbedder:
def __init__(self, params: dict):
pipeline, source = load_pipeline(params)
self.pipeline = pipeline
logging.info(f'Multimodal: loaded pipeline {self.pipeline.name()} from pipelines/{source} ({self.pipeline.__class__.__name__})')
def _split_prompt(self, prompt: str, load_images: bool = False) -> List[PromptPart]:
"""Splits a prompt into a list of `PromptParts` to separate image data from text.
It will also append `image_start` and `image_end` before and after the image, and optionally parse and load the images,
if `load_images` is `True`.
"""
parts: List[PromptPart] = []
curr = 0
while True:
match = re.search(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', prompt[curr:])
if match is None:
# no more image tokens, append the rest of the prompt
if curr > 0:
# add image end token after last image
parts.append(PromptPart(text=self.pipeline.image_end() + prompt[curr:]))
else:
parts.append(PromptPart(text=prompt))
break
# found an image, append image start token to the text
if match.start() > 0:
parts.append(PromptPart(text=prompt[curr:curr+match.start()]+self.pipeline.image_start()))
else:
parts.append(PromptPart(text=self.pipeline.image_start()))
# append the image
parts.append(PromptPart(
text=match.group(0),
image=Image.open(BytesIO(base64.b64decode(match.group(1)))) if load_images else None,
is_image=True
))
curr += match.end()
return parts
def _len_in_tokens_prompt_parts(self, parts: List[PromptPart]) -> int:
"""Total length in tokens of all `parts`"""
tokens = 0
for part in parts:
if part.is_image:
tokens += self.pipeline.num_image_embeds()
elif part.input_ids is not None:
tokens += len(part.input_ids)
else:
tokens += len(encode(part.text)[0])
return tokens
def len_in_tokens(self, prompt: str) -> int:
"""Total length in tokens for a given text `prompt`"""
parts = self._split_prompt(prompt, False)
return self._len_in_tokens_prompt_parts(parts)
def _encode_single_text(self, part: PromptPart, add_bos_token: bool) -> PromptPart:
"""Encode a single prompt `part` to `input_ids`. Returns a `PromptPart`"""
if part.is_image:
placeholders = torch.ones((self.pipeline.num_image_embeds())) * self.pipeline.placeholder_token_id()
part.input_ids = placeholders.to(shared.model.device, dtype=torch.int64)
else:
part.input_ids = encode(part.text, add_bos_token=add_bos_token)[0].to(shared.model.device, dtype=torch.int64)
return part
@staticmethod
def _num_images(parts: List[PromptPart]) -> int:
count = 0
for part in parts:
if part.is_image:
count += 1
return count
def _encode_text(self, state, parts: List[PromptPart]) -> List[PromptPart]:
"""Encode text to token_ids, also truncate the prompt, if necessary.
The chat/instruct mode should make prompts that fit in get_max_prompt_length, but if max_new_tokens are set
such that the context + min_rows don't fit, we can get a prompt which is too long.
We can't truncate image embeddings, as it leads to broken generation, so remove the images instead and warn the user
"""
encoded: List[PromptPart] = []
for i, part in enumerate(parts):
encoded.append(self._encode_single_text(part, i==0 and state['add_bos_token']))
# truncation:
max_len = get_max_prompt_length(state)
removed_images = 0
# 1. remove entire text/image blocks
while self._len_in_tokens_prompt_parts(encoded[1:]) > max_len:
if encoded[0].is_image:
removed_images += 1
encoded = encoded[1:]
# 2. check if the last prompt part doesn't need to get truncated
if self._len_in_tokens_prompt_parts(encoded) > max_len:
if encoded[0].is_image:
# don't truncate image embeddings, just remove the image, otherwise generation will be broken
removed_images += 1
encoded = encoded[1:]
elif len(encoded) > 1 and encoded[0].text.endswith(self.pipeline.image_start()):
# see if we can keep image_start token
len_image_start = len(encode(self.pipeline.image_start(), add_bos_token=state['add_bos_token'])[0])
if self._len_in_tokens_prompt_parts(encoded[1:]) + len_image_start > max_len:
# we can't -> remove this text, and the image
encoded = encoded[2:]
removed_images += 1
else:
# we can -> just truncate the text
trunc_len = self._len_in_tokens_prompt_parts(encoded) - max_len
encoded[0].input_ids = encoded[0].input_ids[trunc_len:]
elif len(encoded) > 0:
# only one text left, truncate it normally
trunc_len = self._len_in_tokens_prompt_parts(encoded) - max_len
encoded[0].input_ids = encoded[0].input_ids[trunc_len:]
# notify user if we truncated an image
if removed_images > 0:
logging.warning(f"Multimodal: removed {removed_images} image(s) from prompt. Try decreasing max_new_tokens if generation is broken")
return encoded
def _embed(self, parts: List[PromptPart]) -> List[PromptPart]:
# batch images
image_indicies = [i for i, part in enumerate(parts) if part.is_image]
embedded = self.pipeline.embed_images([parts[i].image for i in image_indicies])
for i, embeds in zip(image_indicies, embedded):
parts[i].embedding = embeds
# embed text
for (i, part) in enumerate(parts):
if not part.is_image:
parts[i].embedding = self.pipeline.embed_tokens(part.input_ids)
return parts
def _remove_old_images(self, parts: List[PromptPart], params: dict) -> List[PromptPart]:
if params['add_all_images_to_prompt']:
return parts
already_added = False
for i, part in reversed(list(enumerate(parts))):
if part.is_image:
if already_added:
parts[i].embedding = self.pipeline.placeholder_embeddings()
else:
already_added = True
return parts
def forward(self, prompt: str, state: Any, params: dict):
prompt_parts = self._split_prompt(prompt, True)
prompt_parts = self._encode_text(state, prompt_parts)
prompt_parts = self._embed(prompt_parts)
prompt_parts = self._remove_old_images(prompt_parts, params)
embeds = tuple(part.embedding for part in prompt_parts)
ids = tuple(part.input_ids for part in prompt_parts)
input_embeds = torch.cat(embeds, dim=0)
input_ids = torch.cat(ids, dim=0)
return prompt, input_ids, input_embeds, self._num_images(prompt_parts)

View file

@ -0,0 +1,52 @@
import logging
import traceback
from importlib import import_module
from pathlib import Path
from typing import Tuple
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
from modules import shared
def _get_available_pipeline_modules():
pipeline_path = Path(__file__).parent / 'pipelines'
modules = [p for p in pipeline_path.iterdir() if p.is_dir()]
return [m.name for m in modules if (m / 'pipelines.py').exists()]
def load_pipeline(params: dict) -> Tuple[AbstractMultimodalPipeline, str]:
pipeline_modules = {}
available_pipeline_modules = _get_available_pipeline_modules()
for name in available_pipeline_modules:
try:
pipeline_modules[name] = import_module(f'extensions.multimodal.pipelines.{name}.pipelines')
except:
logging.warning(f'Failed to get multimodal pipelines from {name}')
logging.warning(traceback.format_exc())
if shared.args.multimodal_pipeline is not None:
for k in pipeline_modules:
if hasattr(pipeline_modules[k], 'get_pipeline'):
pipeline = getattr(pipeline_modules[k], 'get_pipeline')(shared.args.multimodal_pipeline, params)
if pipeline is not None:
return (pipeline, k)
else:
model_name = shared.args.model.lower()
for k in pipeline_modules:
if hasattr(pipeline_modules[k], 'get_pipeline_from_model_name'):
pipeline = getattr(pipeline_modules[k], 'get_pipeline_from_model_name')(model_name, params)
if pipeline is not None:
return (pipeline, k)
available = []
for k in pipeline_modules:
if hasattr(pipeline_modules[k], 'available_pipelines'):
pipelines = getattr(pipeline_modules[k], 'available_pipelines')
available += pipelines
if shared.args.multimodal_pipeline is not None:
log = f'Multimodal - ERROR: Failed to load multimodal pipeline "{shared.args.multimodal_pipeline}", available pipelines are: {available}.'
else:
log = f'Multimodal - ERROR: Failed to determine multimodal pipeline for model {shared.args.model}, please select one manually using --multimodal-pipeline [PIPELINE]. Available pipelines are: {available}.'
logging.critical(f'{log} Please specify a correct pipeline, or disable the extension')
raise RuntimeError(f'{log} Please specify a correct pipeline, or disable the extension')

View file

@ -0,0 +1,9 @@
## LLaVA pipeline
This module provides 2 pipelines:
- `llava-7b` - for use with LLaVA v0 7B model (finetuned LLaMa 7B)
- `llava-13b` - for use with LLaVA v0 13B model (finetuned LLaMa 13B)
[LLaVA](https://github.com/haotian-liu/LLaVA) uses CLIP `openai/clip-vit-large-patch14` as the vision model, and then a single linear layer. For 13B the projector weights are in `liuhaotian/LLaVA-13b-delta-v0`, and for 7B they are in `liuhaotian/LLaVA-7b-delta-v0`.
The supported parameter combinations for both the vision model, and the projector are: CUDA/32bit, CUDA/16bit, CPU/32bit

View file

@ -0,0 +1,139 @@
import logging
import time
from abc import abstractmethod
from typing import List, Tuple
import torch
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
from huggingface_hub import hf_hub_download
from modules import shared
from modules.text_generation import encode
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel
class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
CLIP_REPO = "openai/clip-vit-large-patch14"
def __init__(self, params: dict) -> None:
super().__init__()
self.clip_device = self._get_device("vision_device", params)
self.clip_dtype = self._get_dtype("vision_bits", params)
self.projector_device = self._get_device("projector_device", params)
self.projector_dtype = self._get_dtype("projector_bits", params)
self.image_processor, self.vision_tower, self.mm_projector = self._load_models()
def _load_models(self):
start_ts = time.time()
logging.info(f"LLaVA - Loading CLIP from {LLaVA_v0_Pipeline.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
image_processor = CLIPImageProcessor.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype)
vision_tower = CLIPVisionModel.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype).to(self.clip_device)
logging.info(f"LLaVA - Loading projector from {self.llava_projector_repo()} as {self.projector_dtype} on {self.projector_device}...")
projector_path = hf_hub_download(self.llava_projector_repo(), self.llava_projector_filename())
mm_projector = torch.nn.Linear(*self.llava_projector_shape())
projector_data = torch.load(projector_path)
mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False)
mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
mm_projector = mm_projector.to(self.projector_device)
logging.info(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
return image_processor, vision_tower, mm_projector
@staticmethod
def image_start() -> str:
return "<im_start>"
@staticmethod
def image_end() -> str:
return "<im_end>"
@staticmethod
def num_image_embeds() -> int:
return 256
@staticmethod
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
return shared.model.model.embed_tokens(input_ids).to(shared.model.device, dtype=shared.model.dtype)
@staticmethod
def placeholder_embeddings() -> torch.Tensor:
return LLaVA_v0_Pipeline.embed_tokens(encode("<im_patch>"*256, add_bos_token=False)[0])
def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
images = self.image_processor(images, return_tensors='pt')['pixel_values']
images = images.to(self.clip_device, dtype=self.clip_dtype)
with torch.no_grad():
image_forward_outs = self.vision_tower(images, output_hidden_states=True)
select_hidden_state_layer = -2
select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
image_features = select_hidden_state[:, 1:].to(self.projector_device, dtype=self.projector_dtype)
image_features = self.mm_projector(image_features)
return image_features.to(shared.model.device, dtype=shared.model.dtype)
@staticmethod
@abstractmethod
def llava_projector_repo() -> str:
pass
@staticmethod
@abstractmethod
def llava_projector_filename() -> str:
pass
@staticmethod
@abstractmethod
def llava_projector_shape() -> Tuple[int, int]:
pass
class LLaVA_v0_13B_Pipeline(LLaVA_v0_Pipeline):
def __init__(self, params: dict) -> None:
super().__init__(params)
@staticmethod
def name() -> str:
return "llava-13b"
@staticmethod
def placeholder_token_id() -> int:
return 32000
@staticmethod
def llava_projector_shape() -> Tuple[int, int]:
return (1024, 5120)
@staticmethod
def llava_projector_filename() -> str:
return "mm_projector.bin"
@staticmethod
def llava_projector_repo() -> str:
return "liuhaotian/LLaVA-13b-delta-v0"
class LLaVA_v0_7B_Pipeline(LLaVA_v0_Pipeline):
def __init__(self, params: dict) -> None:
super().__init__(params)
@staticmethod
def name() -> str:
return "llava-7b"
@staticmethod
def placeholder_token_id() -> int:
return 32001
@staticmethod
def llava_projector_shape() -> Tuple[int, int]:
return (1024, 4096)
@staticmethod
def llava_projector_filename() -> str:
return "mm_projector.bin"
@staticmethod
def llava_projector_repo() -> str:
return "liuhaotian/LLaVA-7b-delta-v0"

View file

@ -0,0 +1,27 @@
from typing import Optional
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
available_pipelines = ['llava-7b', 'llava-13b']
def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
if name == 'llava-7b':
from .llava import LLaVA_v0_7B_Pipeline
return LLaVA_v0_7B_Pipeline(params)
if name == 'llava-13b':
from .llava import LLaVA_v0_13B_Pipeline
return LLaVA_v0_13B_Pipeline(params)
return None
def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
if 'llava' not in model_name.lower():
return None
if '7b' in model_name.lower():
from .llava import LLaVA_v0_7B_Pipeline
return LLaVA_v0_7B_Pipeline(params)
if '13b' in model_name.lower():
from .llava import LLaVA_v0_13B_Pipeline
return LLaVA_v0_13B_Pipeline(params)
return None

View file

@ -0,0 +1,103 @@
import base64
import logging
import re
import time
from functools import partial
from io import BytesIO
import gradio as gr
import torch
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
from modules import shared
params = {
"add_all_images_to_prompt": False,
# device to run vision encoder on
"vision_device": None,
# bits to load vision encoder in, either 16 or 32
"vision_bits": 32,
# device to run multimodal projector on
"projector_device": None,
# multimodal projector bits, either 32 or 16
"projector_bits": 32
}
# If 'state' is True, will hijack the next chat generation
input_hijack = {
'state': False,
'value': ["", ""]
}
# initialized in ui, so that params are loaded from settings
multimodal_embedder: MultimodalEmbedder = None
def add_chat_picture(picture, text, visible_text):
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
max_hw, min_hw = max(picture.size), min(picture.size)
aspect_ratio = max_hw / min_hw
shortest_edge = int(max(300 / aspect_ratio, 224))
longest_edge = int(shortest_edge * aspect_ratio)
w = shortest_edge if picture.width < picture.height else longest_edge
h = shortest_edge if picture.width >= picture.height else longest_edge
picture = picture.resize((w,h))
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
image = f'<img src="data:image/jpeg;base64,{img_str}">'
if '<image>' in text:
text = text.replace('<image>', image)
else:
text = text + '\n' + image
if visible_text == '' or visible_text is None:
visible_text = text
elif '<image>' in visible_text:
visible_text = visible_text.replace('<image>', image)
else:
visible_text = visible_text + '\n' + image
return text, visible_text
def custom_tokenized_length(prompt):
return multimodal_embedder.len_in_tokens(prompt)
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
global params
start_ts = time.time()
image_match = re.search(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', prompt)
if image_match is None:
return prompt, input_ids, input_embeds
prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
logging.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
return (prompt,
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
def ui():
global multimodal_embedder
multimodal_embedder = MultimodalEmbedder(params)
with gr.Column():
picture_select = gr.Image(label='Send a picture', type='pil')
# The models don't seem to deal well with multiple images
single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
# Prepare the input hijack
picture_select.upload(
lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
[picture_select],
None
)
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None)
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
shared.gradio['Generate'].click(lambda: None, None, picture_select)
shared.gradio['textbox'].submit(lambda: None, None, picture_select)

View file

@ -14,7 +14,7 @@ from PIL import Image
import modules.shared as shared
from modules.extensions import apply_extensions
from modules.html_generator import chat_html_wrapper, make_thumbnail
from modules.text_generation import (encode, generate_reply,
from modules.text_generation import (generate_reply, get_encoded_length,
get_max_prompt_length)
@ -67,7 +67,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Building the prompt
i = len(history) - 1
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
while i >= 0 and get_encoded_length(''.join(rows)) < max_length:
if _continue and i == len(history) - 1:
rows.insert(1, bot_turn_stripped + history[i][1].strip())
else:
@ -90,7 +90,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Adding the Character prefix
rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
while len(rows) > min_rows and get_encoded_length(''.join(rows)) >= max_length:
rows.pop(1)
prompt = ''.join(rows)

View file

@ -7,6 +7,7 @@ import gradio as gr
import extensions
import modules.shared as shared
state = {}
available_extensions = []
setup_called = set()
@ -73,14 +74,11 @@ def _apply_input_hijack(text, visible_text):
return text, visible_text
# custom_generate_chat_prompt handling
# custom_generate_chat_prompt handling - currently only the first one will work
def _apply_custom_generate_chat_prompt(text, state, **kwargs):
custom_generate_chat_prompt = None
for extension, _ in iterator():
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
custom_generate_chat_prompt = extension.custom_generate_chat_prompt
if custom_generate_chat_prompt is not None:
if hasattr(extension, 'custom_generate_chat_prompt'):
return custom_generate_chat_prompt(text, state, **kwargs)
return None
@ -95,16 +93,26 @@ def _apply_state_modifier_extensions(state):
return state
# Extension functions that override the default tokenizer output
# Extension functions that override the default tokenizer output - currently only the first one will work
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
for extension, _ in iterator():
if hasattr(extension, function_name):
prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
return getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
return prompt, input_ids, input_embeds
# Custom generate reply handling
# Get prompt length in tokens after applying extension functions which override the default tokenizer output
# currently only the first one will work
def _apply_custom_tokenized_length(prompt):
for extension, _ in iterator():
if hasattr(extension, 'custom_tokenized_length'):
return getattr(extension, 'custom_tokenized_length')(prompt)
return None
# Custom generate reply handling - currently only the first one will work
def _apply_custom_generate_reply():
for extension, _ in iterator():
if hasattr(extension, 'custom_generate_reply'):
@ -121,7 +129,8 @@ EXTENSION_MAP = {
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
"input_hijack": _apply_input_hijack,
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
"custom_generate_reply": _apply_custom_generate_reply
"custom_generate_reply": _apply_custom_generate_reply,
"tokenized_length": _apply_custom_tokenized_length
}

View file

@ -166,6 +166,8 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
# Multimodal
parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')
args = parser.parse_args()
args_defaults = parser.parse_args([])
@ -183,12 +185,21 @@ if args.trust_remote_code:
if args.share:
logging.warning("The gradio \"share link\" feature downloads a proprietary and unaudited blob to create a reverse tunnel. This is potentially dangerous.")
def add_extension(name):
if args.extensions is None:
args.extensions = [name]
elif 'api' not in args.extensions:
args.extensions.append(name)
# Activating the API extension
if args.api or args.public_api:
if args.extensions is None:
args.extensions = ['api']
elif 'api' not in args.extensions:
args.extensions.append('api')
add_extension('api')
# Activating the multimodal extension
if args.multimodal_pipeline is not None:
add_extension('multimodal')
def is_chat():

View file

@ -59,6 +59,14 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
return input_ids.cuda()
def get_encoded_length(prompt):
length_after_extensions = apply_extensions('tokenized_length', prompt)
if length_after_extensions is not None:
return length_after_extensions
return len(encode(prompt)[0])
def decode(output_ids, skip_special_tokens=True):
return shared.tokenizer.decode(output_ids, skip_special_tokens)

View file

@ -48,7 +48,7 @@ from modules import chat, shared, training, ui, utils
from modules.html_generator import chat_html_wrapper
from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt, unload_model
from modules.text_generation import encode, generate_reply, stop_everything_event
from modules.text_generation import generate_reply, get_encoded_length, stop_everything_event
def load_model_wrapper(selected_model, autoload=False):
@ -140,7 +140,7 @@ def load_prompt(fname):
def count_tokens(text):
tokens = len(encode(text)[0])
tokens = get_encoded_length(text)
return f'{tokens} tokens in the input.'