Style changes

This commit is contained in:
oobabooga 2023-05-09 20:20:35 -03:00
parent e9e75a9ec7
commit 8fa5f651d6
3 changed files with 9 additions and 10 deletions

View file

@ -46,7 +46,7 @@ class MultimodalEmbedder:
break break
# found an image, append image start token to the text # found an image, append image start token to the text
if match.start() > 0: if match.start() > 0:
parts.append(PromptPart(text=prompt[curr:curr+match.start()]+self.pipeline.image_start())) parts.append(PromptPart(text=prompt[curr:curr + match.start()] + self.pipeline.image_start()))
else: else:
parts.append(PromptPart(text=self.pipeline.image_start())) parts.append(PromptPart(text=self.pipeline.image_start()))
# append the image # append the image
@ -101,7 +101,7 @@ class MultimodalEmbedder:
""" """
encoded: List[PromptPart] = [] encoded: List[PromptPart] = []
for i, part in enumerate(parts): for i, part in enumerate(parts):
encoded.append(self._encode_single_text(part, i==0 and state['add_bos_token'])) encoded.append(self._encode_single_text(part, i == 0 and state['add_bos_token']))
# truncation: # truncation:
max_len = get_max_prompt_length(state) max_len = get_max_prompt_length(state)

View file

@ -42,14 +42,13 @@ def add_chat_picture(picture, text, visible_text):
longest_edge = int(shortest_edge * aspect_ratio) longest_edge = int(shortest_edge * aspect_ratio)
w = shortest_edge if picture.width < picture.height else longest_edge w = shortest_edge if picture.width < picture.height else longest_edge
h = shortest_edge if picture.width >= picture.height else longest_edge h = shortest_edge if picture.width >= picture.height else longest_edge
picture = picture.resize((w,h)) picture = picture.resize((w, h))
buffer = BytesIO() buffer = BytesIO()
picture.save(buffer, format="JPEG") picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8') img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
image = f'<img src="data:image/jpeg;base64,{img_str}">' image = f'<img src="data:image/jpeg;base64,{img_str}">'
if '<image>' in text: if '<image>' in text:
text = text.replace('<image>', image) text = text.replace('<image>', image)
else: else:
@ -97,7 +96,7 @@ def ui():
[picture_select], [picture_select],
None None
) )
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None) picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["", ""]}), None, None)
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None) single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
shared.gradio['Generate'].click(lambda: None, None, picture_select) shared.gradio['Generate'].click(lambda: None, None, picture_select)
shared.gradio['textbox'].submit(lambda: None, None, picture_select) shared.gradio['textbox'].submit(lambda: None, None, picture_select)