Style changes

oobabooga 2023-05-09 20:20:35 -03:00
parent e9e75a9ec7
commit 8fa5f651d6
3 changed files with 9 additions and 10 deletions

View file

@@ -46,12 +46,12 @@ class MultimodalEmbedder:
                 break
             # found an image, append image start token to the text
             if match.start() > 0:
-                parts.append(PromptPart(text=prompt[curr:curr+match.start()]+self.pipeline.image_start()))
+                parts.append(PromptPart(text=prompt[curr:curr + match.start()] + self.pipeline.image_start()))
             else:
                 parts.append(PromptPart(text=self.pipeline.image_start()))
             # append the image
             parts.append(PromptPart(
                 text=match.group(0),
                 image=Image.open(BytesIO(base64.b64decode(match.group(1)))) if load_images else None,
                 is_image=True
             ))
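
A hedged sketch of the splitting loop in this hunk, for readers outside the diff context: scan the prompt for inline base64 <img> tags, emit a text part for each gap and an image part for each match. The regex and the (kind, value) tuples below are illustrative stand-ins, not the exact pattern or PromptPart class used by MultimodalEmbedder.

    import base64
    import re
    from io import BytesIO

    from PIL import Image

    IMG_RE = re.compile(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">')

    def split_prompt(prompt: str, load_images: bool = False):
        parts, curr = [], 0
        while True:
            match = IMG_RE.search(prompt[curr:])
            if match is None:
                break
            if match.start() > 0:
                # text that precedes the image
                parts.append(('text', prompt[curr:curr + match.start()]))
            # the image itself, optionally decoded into a PIL image
            image = Image.open(BytesIO(base64.b64decode(match.group(1)))) if load_images else None
            parts.append(('image', image))
            curr += match.end()
        if curr < len(prompt):
            parts.append(('text', prompt[curr:]))
        return parts
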
@@ -94,14 +94,14 @@ class MultimodalEmbedder:
     def _encode_text(self, state, parts: List[PromptPart]) -> List[PromptPart]:
         """Encode text to token_ids, also truncate the prompt, if necessary.
         The chat/instruct mode should make prompts that fit in get_max_prompt_length, but if max_new_tokens are set
         such that the context + min_rows don't fit, we can get a prompt which is too long.
         We can't truncate image embeddings, as it leads to broken generation, so remove the images instead and warn the user
         """
         encoded: List[PromptPart] = []
         for i, part in enumerate(parts):
-            encoded.append(self._encode_single_text(part, i==0 and state['add_bos_token']))
+            encoded.append(self._encode_single_text(part, i == 0 and state['add_bos_token']))
         # truncation:
         max_len = get_max_prompt_length(state)
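
The docstring above carries the key rule; here is a minimal sketch of it, assuming PromptPart can be modelled as a dict of token ids plus an is_image flag: trim excess tokens from the front of the prompt, but never split an image embedding, and if an image falls inside the trimmed region, drop it whole and warn the user.

    import logging

    def drop_or_truncate(parts, max_len):
        # parts: list of {'tokens': [...], 'is_image': bool} dicts (assumed shape)
        excess = sum(len(p['tokens']) for p in parts) - max_len
        kept = []
        for part in parts:
            if excess <= 0:
                kept.append(part)
            elif part['is_image']:
                logging.warning('Multimodal: removing an image embedding to fit the prompt')
                excess -= len(part['tokens'])
            else:
                drop = min(excess, len(part['tokens']))
                excess -= drop
                part = dict(part, tokens=part['tokens'][drop:])
                if part['tokens']:
                    kept.append(part)
        return kept
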

View file

@@ -26,7 +26,7 @@ def load_pipeline(params: dict) -> Tuple[AbstractMultimodalPipeline, str]:
     if shared.args.multimodal_pipeline is not None:
         for k in pipeline_modules:
             if hasattr(pipeline_modules[k], 'get_pipeline'):
                 pipeline = getattr(pipeline_modules[k], 'get_pipeline')(shared.args.multimodal_pipeline, params)
                 if pipeline is not None:
                     return (pipeline, k)
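
For reference, the hunk above relies on a simple plugin-discovery pattern: each pipeline module may expose a get_pipeline(name, params) factory that returns None when it does not recognise the requested name, and the first non-None result wins. A self-contained sketch (the module dict is passed in explicitly, an assumption made only to keep the example standalone):

    def find_pipeline(requested, params, pipeline_modules):
        for k, module in pipeline_modules.items():
            if hasattr(module, 'get_pipeline'):
                pipeline = getattr(module, 'get_pipeline')(requested, params)
                if pipeline is not None:
                    return pipeline, k
        raise ValueError(f'No multimodal pipeline provides {requested!r}')
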

View file

@@ -42,14 +42,13 @@ def add_chat_picture(picture, text, visible_text):
     longest_edge = int(shortest_edge * aspect_ratio)
     w = shortest_edge if picture.width < picture.height else longest_edge
     h = shortest_edge if picture.width >= picture.height else longest_edge
-    picture = picture.resize((w,h))
+    picture = picture.resize((w, h))
     buffer = BytesIO()
     picture.save(buffer, format="JPEG")
     img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
     image = f'<img src="data:image/jpeg;base64,{img_str}">'
     if '<image>' in text:
         text = text.replace('<image>', image)
     else:
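
A self-contained sketch of the resize-and-embed step shown in this hunk: scale the shortest edge to a fixed size while keeping the aspect ratio, then inline the JPEG as a base64 data URI inside an <img> tag. The 300 px target and the aspect_ratio formula are assumptions; add_chat_picture computes both just above the lines shown here.

    import base64
    from io import BytesIO

    from PIL import Image

    def picture_to_img_tag(picture: Image.Image, shortest_edge: int = 300) -> str:
        aspect_ratio = max(picture.width, picture.height) / min(picture.width, picture.height)
        longest_edge = int(shortest_edge * aspect_ratio)
        w = shortest_edge if picture.width < picture.height else longest_edge
        h = shortest_edge if picture.width >= picture.height else longest_edge
        picture = picture.resize((w, h))

        buffer = BytesIO()
        picture.convert('RGB').save(buffer, format="JPEG")  # JPEG cannot store an alpha channel
        img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return f'<img src="data:image/jpeg;base64,{img_str}">'
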
@@ -80,8 +79,8 @@ def tokenizer_modifier(state, prompt, input_ids, input_embeds):
     prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
     logging.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
     return (prompt,
             input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
             input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))


 def ui():
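
A tiny, runnable illustration of the return-value fix-up in tokenizer_modifier: the token ids and the mixed text/image embeddings both gain a batch dimension and are moved to the model's device and dtype. CPU and float32 are used below so the snippet runs anywhere; the extension itself uses shared.model.device and shared.model.dtype.

    import torch

    input_ids = torch.tensor([101, 7592, 102])   # (seq_len,)
    input_embeds = torch.randn(3, 4096)          # (seq_len, hidden_size)

    batched_ids = input_ids.unsqueeze(0).to('cpu', dtype=torch.int64)           # (1, seq_len)
    batched_embeds = input_embeds.unsqueeze(0).to('cpu', dtype=torch.float32)   # (1, seq_len, hidden_size)
    print(batched_ids.shape, batched_embeds.shape)
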
@@ -97,7 +96,7 @@ def ui():
         [picture_select],
         None
     )
-    picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None)
+    picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["", ""]}), None, None)
     single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
     shared.gradio['Generate'].click(lambda: None, None, picture_select)
     shared.gradio['textbox'].submit(lambda: None, None, picture_select)