add basic schema extender

This commit is contained in:
Philipp Kühn
2021-02-19 09:54:39 +01:00
parent 85eceb32b8
commit 6f9557294e
7 changed files with 88 additions and 13 deletions

View File

@@ -2,6 +2,7 @@ import { Plugin, Transaction } from 'prosemirror-state'
import { Command as ProseMirrorCommand } from 'prosemirror-commands' import { Command as ProseMirrorCommand } from 'prosemirror-commands'
import { InputRule } from 'prosemirror-inputrules' import { InputRule } from 'prosemirror-inputrules'
import { Editor } from './Editor' import { Editor } from './Editor'
import { Node } from './Node'
import mergeDeep from './utilities/mergeDeep' import mergeDeep from './utilities/mergeDeep'
import { GlobalAttributes, RawCommands } from './types' import { GlobalAttributes, RawCommands } from './types'
@@ -65,6 +66,30 @@ export interface ExtensionConfig<Options = any> {
editor: Editor, editor: Editor,
}) => Plugin[], }) => Plugin[],
/**
* Extend Node Schema
*/
extendNodeSchema?: ((
this: {
options: Options,
},
extension: Node,
) => {
[key: string]: any,
}) | null,
/**
* Extend Mark Schema
*/
extendMarkSchema?: ((
this: {
options: Options,
},
extension: Node,
) => {
[key: string]: any,
}) | null,
/** /**
* The editor is ready. * The editor is ready.
*/ */
@@ -149,6 +174,8 @@ export class Extension<Options = any> {
addInputRules: () => [], addInputRules: () => [],
addPasteRules: () => [], addPasteRules: () => [],
addProseMirrorPlugins: () => [], addProseMirrorPlugins: () => [],
extendNodeSchema: null,
extendMarkSchema: null,
onCreate: null, onCreate: null,
onUpdate: null, onUpdate: null,
onSelection: null, onSelection: null,

View File

@@ -209,6 +209,8 @@ export class Mark<Options = any> {
parseHTML: () => null, parseHTML: () => null,
renderHTML: null, renderHTML: null,
addAttributes: () => ({}), addAttributes: () => ({}),
extendNodeSchema: null,
extendMarkSchema: null,
onCreate: null, onCreate: null,
onUpdate: null, onUpdate: null,
onSelection: null, onSelection: null,

View File

@@ -70,12 +70,6 @@ export interface NodeConfig<Options = any> extends Overwrite<ExtensionConfig<Opt
*/ */
isolating?: NodeSpec['isolating'] | ((this: { options: Options }) => NodeSpec['isolating']), isolating?: NodeSpec['isolating'] | ((this: { options: Options }) => NodeSpec['isolating']),
// TODO: extend via extension-table
/**
* Table Role
*/
tableRole?: NodeSpec['tableRole'] | ((this: { options: Options }) => NodeSpec['tableRole']),
/** /**
* Parse HTML * Parse HTML
*/ */
@@ -284,6 +278,8 @@ export class Node<Options = any> {
renderText: null, renderText: null,
addAttributes: () => ({}), addAttributes: () => ({}),
addNodeView: null, addNodeView: null,
extendNodeSchema: null,
extendMarkSchema: null,
onCreate: null, onCreate: null,
onUpdate: null, onUpdate: null,
onSelection: null, onSelection: null,
@@ -291,8 +287,6 @@ export class Node<Options = any> {
onFocus: null, onFocus: null,
onBlur: null, onBlur: null,
onDestroy: null, onDestroy: null,
// TODO: remove,
tableRole: null,
} }
options!: Options options!: Options

View File

@@ -1,5 +1,6 @@
import { NodeSpec, MarkSpec, Schema } from 'prosemirror-model' import { NodeSpec, MarkSpec, Schema } from 'prosemirror-model'
import { Extensions } from '../types' import { Extensions } from '../types'
import { ExtensionConfig } from '../Extension'
import splitExtensions from './splitExtensions' import splitExtensions from './splitExtensions'
import getAttributesFromExtensions from './getAttributesFromExtensions' import getAttributesFromExtensions from './getAttributesFromExtensions'
import getRenderedAttributes from './getRenderedAttributes' import getRenderedAttributes from './getRenderedAttributes'
@@ -21,11 +22,34 @@ export default function getSchema(extensions: Extensions): Schema {
const allAttributes = getAttributesFromExtensions(extensions) const allAttributes = getAttributesFromExtensions(extensions)
const { nodeExtensions, markExtensions } = splitExtensions(extensions) const { nodeExtensions, markExtensions } = splitExtensions(extensions)
const topNode = nodeExtensions.find(extension => extension.config.topNode)?.config.name const topNode = nodeExtensions.find(extension => extension.config.topNode)?.config.name
const nodeSchemaExtenders: ExtensionConfig['extendNodeSchema'][] = []
const markSchemaExtenders: ExtensionConfig['extendMarkSchema'][] = []
extensions.forEach(extension => {
if (typeof extension.config.extendNodeSchema === 'function') {
nodeSchemaExtenders.push(extension.config.extendNodeSchema)
}
if (typeof extension.config.extendMarkSchema === 'function') {
markSchemaExtenders.push(extension.config.extendMarkSchema)
}
})
const nodes = Object.fromEntries(nodeExtensions.map(extension => { const nodes = Object.fromEntries(nodeExtensions.map(extension => {
const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.config.name) const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.config.name)
const context = { options: extension.options } const context = { options: extension.options }
const extraNodeFields = nodeSchemaExtenders.reduce((fields, nodeSchemaExtender) => {
const extraFields = callOrReturn(nodeSchemaExtender, context, extension)
return {
...fields,
...extraFields,
}
}, {})
const schema: NodeSpec = cleanUpSchemaItem({ const schema: NodeSpec = cleanUpSchemaItem({
...extraNodeFields,
content: callOrReturn(extension.config.content, context), content: callOrReturn(extension.config.content, context),
marks: callOrReturn(extension.config.marks, context), marks: callOrReturn(extension.config.marks, context),
group: callOrReturn(extension.config.group, context), group: callOrReturn(extension.config.group, context),
@@ -36,7 +60,6 @@ export default function getSchema(extensions: Extensions): Schema {
code: callOrReturn(extension.config.code, context), code: callOrReturn(extension.config.code, context),
defining: callOrReturn(extension.config.defining, context), defining: callOrReturn(extension.config.defining, context),
isolating: callOrReturn(extension.config.isolating, context), isolating: callOrReturn(extension.config.isolating, context),
tableRole: callOrReturn(extension.config.tableRole, context),
attrs: Object.fromEntries(extensionAttributes.map(extensionAttribute => { attrs: Object.fromEntries(extensionAttributes.map(extensionAttribute => {
return [extensionAttribute.name, { default: extensionAttribute?.attribute?.default }] return [extensionAttribute.name, { default: extensionAttribute?.attribute?.default }]
})), })),
@@ -61,7 +84,18 @@ export default function getSchema(extensions: Extensions): Schema {
const marks = Object.fromEntries(markExtensions.map(extension => { const marks = Object.fromEntries(markExtensions.map(extension => {
const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.config.name) const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.config.name)
const context = { options: extension.options } const context = { options: extension.options }
const extraMarkFields = markSchemaExtenders.reduce((fields, markSchemaExtender) => {
const extraFields = callOrReturn(markSchemaExtender, context, extension)
return {
...fields,
...extraFields,
}
}, {})
const schema: MarkSpec = cleanUpSchemaItem({ const schema: MarkSpec = cleanUpSchemaItem({
...extraMarkFields,
inclusive: callOrReturn(extension.config.inclusive, context), inclusive: callOrReturn(extension.config.inclusive, context),
excludes: callOrReturn(extension.config.excludes, context), excludes: callOrReturn(extension.config.excludes, context),
group: callOrReturn(extension.config.group, context), group: callOrReturn(extension.config.group, context),

View File

@@ -8,11 +8,12 @@ export { default as nodeInputRule } from './inputRules/nodeInputRule'
export { default as markInputRule } from './inputRules/markInputRule' export { default as markInputRule } from './inputRules/markInputRule'
export { default as markPasteRule } from './pasteRules/markPasteRule' export { default as markPasteRule } from './pasteRules/markPasteRule'
export { default as callOrReturn } from './utilities/callOrReturn'
export { default as mergeAttributes } from './utilities/mergeAttributes'
export { default as generateHTML } from './helpers/generateHTML' export { default as generateHTML } from './helpers/generateHTML'
export { default as getSchema } from './helpers/getSchema' export { default as getSchema } from './helpers/getSchema'
export { default as getHTMLFromFragment } from './helpers/getHTMLFromFragment' export { default as getHTMLFromFragment } from './helpers/getHTMLFromFragment'
export { default as getMarkAttributes } from './helpers/getMarkAttributes' export { default as getMarkAttributes } from './helpers/getMarkAttributes'
export { default as mergeAttributes } from './utilities/mergeAttributes'
export { default as isActive } from './helpers/isActive' export { default as isActive } from './helpers/isActive'
export { default as isMarkActive } from './helpers/isMarkActive' export { default as isMarkActive } from './helpers/isMarkActive'
export { default as isNodeActive } from './helpers/isNodeActive' export { default as isNodeActive } from './helpers/isNodeActive'

View File

@@ -3,14 +3,15 @@
* Otherwise it is returned directly. * Otherwise it is returned directly.
* @param value Function or any value. * @param value Function or any value.
* @param context Optional context to bind to function. * @param context Optional context to bind to function.
* @param props Optional props to pass to function.
*/ */
export default function callOrReturn(value: any, context?: any): any { export default function callOrReturn(value: any, context: any = undefined, ...props: any[]): any {
if (typeof value === 'function') { if (typeof value === 'function') {
if (context) { if (context) {
return value.bind(context)() return value.bind(context)(...props)
} }
return value() return value(...props)
} }
return value return value

View File

@@ -4,6 +4,7 @@ import {
mergeAttributes, mergeAttributes,
isCellSelection, isCellSelection,
findParentNodeClosestToPos, findParentNodeClosestToPos,
callOrReturn,
} from '@tiptap/core' } from '@tiptap/core'
import { import {
tableEditing, tableEditing,
@@ -64,6 +65,13 @@ declare module '@tiptap/core' {
fixTables: () => Command, fixTables: () => Command,
} }
} }
interface NodeConfig<Options> {
/**
* Table Role
*/
tableRole?: string | ((this: { options: Options }) => string),
}
} }
export const Table = Node.create<TableOptions>({ export const Table = Node.create<TableOptions>({
@@ -81,6 +89,14 @@ export const Table = Node.create<TableOptions>({
allowTableNodeSelection: false, allowTableNodeSelection: false,
}, },
extendNodeSchema(extension) {
const context = { options: extension.options }
return {
tableRole: callOrReturn(extension.config.tableRole, context),
}
},
content: 'tableRow+', content: 'tableRow+',
tableRole: 'table', tableRole: 'table',