diff --git a/packages/core/src/utils/getSchema.ts b/packages/core/src/utils/getSchema.ts index 0bf7251f..a892d470 100644 --- a/packages/core/src/utils/getSchema.ts +++ b/packages/core/src/utils/getSchema.ts @@ -1,5 +1,7 @@ -import { NodeSpec, MarkSpec, Schema } from 'prosemirror-model' -import { Extensions } from '../types' +import { + NodeSpec, MarkSpec, Schema, ParseRule, +} from 'prosemirror-model' +import { ExtensionAttribute, Extensions } from '../types' import splitExtensions from './splitExtensions' import getAttributesFromExtensions from './getAttributesFromExtensions' import getRenderedAttributes from './getRenderedAttributes' @@ -15,18 +17,35 @@ function cleanUpSchemaItem(data: T) { })) as T } +function injectExtensionAttributes(parseRule: ParseRule, extensionAttributes: ExtensionAttribute[]): ParseRule { + if (parseRule.style) { + return parseRule + } + + return { + ...parseRule, + getAttrs: node => { + const oldAttributes = parseRule.getAttrs ? parseRule.getAttrs(node) : {} + const newAttributes = extensionAttributes + .filter(item => item.attribute.rendered) + .reduce((items, item) => ({ + ...items, + ...item.attribute.parseHTML(node as HTMLElement), + }), {}) + + return { ...oldAttributes, ...newAttributes } + }, + } +} + export default function getSchema(extensions: Extensions): Schema { const allAttributes = getAttributesFromExtensions(extensions) const { nodeExtensions, markExtensions } = splitExtensions(extensions) const topNode = nodeExtensions.find(extension => extension.topNode)?.name const nodes = Object.fromEntries(nodeExtensions.map(extension => { - const context = { - options: extension.options, - } - const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.name) - + const context = { options: extension.options } const schema: NodeSpec = cleanUpSchemaItem({ content: extension.content, marks: extension.marks, @@ -41,46 +60,27 @@ export default function getSchema(extensions: Extensions): Schema { attrs: Object.fromEntries(extensionAttributes.map(extensionAttribute => { return [extensionAttribute.name, { default: extensionAttribute?.attribute?.default }] })), - parseDOM: extension.parseHTML.bind(context)()?.map(parseRule => { - if (parseRule.style) { - return parseRule - } - - return { - ...parseRule, - getAttrs: node => { - const oldAttributes = parseRule.getAttrs ? parseRule.getAttrs(node) : {} - const newAttributes = extensionAttributes - .filter(item => item.attribute.rendered) - .reduce((items, item) => ({ - ...items, - ...item.attribute.parseHTML(node as HTMLElement), - }), {}) - - return { ...oldAttributes, ...newAttributes } - }, - } - }), - ...(typeof extension.renderHTML === 'function') && { - toDOM: node => { - return (extension.renderHTML as Function).bind(context)({ - node, - attributes: getRenderedAttributes(node, extensionAttributes), - }) - }, - }, }) + if (extension.parseHTML) { + schema.parseDOM = extension.parseHTML + .bind(context)() + ?.map(parseRule => injectExtensionAttributes(parseRule, extensionAttributes)) + } + + if (extension.renderHTML) { + schema.toDOM = node => (extension.renderHTML as Function)?.bind(context)({ + node, + attributes: getRenderedAttributes(node, extensionAttributes), + }) + } + return [extension.name, schema] })) const marks = Object.fromEntries(markExtensions.map(extension => { - const context = { - options: extension.options, - } - const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.name) - + const context = { options: extension.options } const schema: MarkSpec = cleanUpSchemaItem({ inclusive: extension.inclusive, excludes: extension.excludes, @@ -89,17 +89,21 @@ export default function getSchema(extensions: Extensions): Schema { attrs: Object.fromEntries(extensionAttributes.map(extensionAttribute => { return [extensionAttribute.name, { default: extensionAttribute?.attribute?.default }] })), - parseDOM: extension.parseHTML.bind(context)(), - ...(typeof extension.renderHTML === 'function') && { - toDOM: mark => { - return (extension.renderHTML as Function).bind(context)({ - mark, - attributes: getRenderedAttributes(mark, extensionAttributes), - }) - }, - }, }) + if (extension.parseHTML) { + schema.parseDOM = extension.parseHTML + .bind(context)() + ?.map(parseRule => injectExtensionAttributes(parseRule, extensionAttributes)) + } + + if (extension.renderHTML) { + schema.toDOM = mark => (extension.renderHTML as Function)?.bind(context)({ + mark, + attributes: getRenderedAttributes(mark, extensionAttributes), + }) + } + return [extension.name, schema] }))